从pytorch数据集返回索引:用于更改_getitem__的函数会导致元类冲突

2024-04-25 11:56:58 发布

您现在位置:Python中文网/ 问答频道 /正文

我有多个类(用于不同的数据集)继承自pytorch的Dataset类。它们有一个总体结构,如下所示:

from torch.utils.data import Dataset

class SomeDataset(Dataset):

    def __init__(self, data, labels):
        super(SomeDataset, self).__init__()
        self.data = data
        self.labels = labels
        self.__name__ = 'SomeDataset'

    def __getitem__(self, index):
        return {'data': self.data[index], 'label': self.labels[index]}

    def __len__(self):
        return len(data)

最近,我意识到在批处理时跟踪传递到Dataloader中的标签是有益的,因此在谷歌搜索如何做到这一点时,我遇到了this thread,这就是我修改代码以编写此函数的地方:

def return_indices(dataset_class):
    
    def __getitem__(self, index):
        return {'index':1, **dataset_class.__getitem__(self, index)}

    return type(dataset_class.__name__, (dataset_class, ), {'__getitem__': __getitem__})

我以前从未见过type像这样使用,但在谷歌搜索之后,它变得有意义,所以我尝试了一下。不幸的是,这导致了以下错误:

TypeError: metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass of the metaclasses of all its bases

这导致了更多的谷歌搜索,尽管我开始理解元类是什么以及元类是如何使用的,但我仍然无法找出这种方法的错误所在,或者如何解决它——我开始认为,也许将此功能重写到我的dataset类中会更容易,而不是使用一些整洁的包装器来为我实现。有人能说出我所缺少的东西吗


Tags: ofnameselfdataindexlabelslenreturn
1条回答
网友
1楼 · 发布于 2024-04-25 11:56:58

只要这样做:

def return_indices(dataset_class):
    
    def __getitem__(self, index):
        return {'index':1, **dataset_class.__getitem__(self, index)}
    metacls = type(dataset_class)
    return metacls(dataset_class.__name__, (dataset_class, ), {'__getitem__': __getitem__})

发生了什么:正如您所发现的,对type的3参数调用是一种在Python中以编程方式创建新类的方法,而不需要“class”语句及其主体

但是type是“基本元类”——虽然它的实例是普通类,但它也“硬编码”了您正在创建的类的元类——相反,使用class语句将使Python在您正在创建的类的基础中搜索合适的元类

只需使用派生类元类(可以通过类型的单参数形式获得,如上所述,也可以通过类的__class__属性获得,如dataset_class.__class__

使用它作为类型的可调用替代将使其本身成为元类,并且应该可以正常工作

NB:由于元类还有一些机制,比如__prepare__,仅仅调用元类而不是type并不总是有效的-正确的通用方法是调用types.prepare_classtypes.new_class以及具有回调以执行类语句体中发生的类体的执行。大多数情况下不需要这样做

相关问题 更多 >

    热门问题