PyTorch: selecting batches based on a tag column in the data

Posted 2024-04-26 05:06:17


I have a dataset like this:

[table omitted: feature columns, a target column, and a tag column with entries such as tag1, tag2, tag3]

The number of entries per tag is not always the same.

My goal is to load only data that carries one or more specific tags, so that with batch_size=1 one mini-batch contains only entries tagged tag1 and the next mini-batch only entries tagged tag2, or, with batch_size=2, entries tagged tag1 and tag2 together, for example.

So far my code ignores the tag column entirely and just picks batches at random.

I build the dataset like this:

# features is a matrix with all the features columns through all rows
# target is a vector with the target column through all rows
featuresTrain, targetTrain = projutils.get_data(train=True, config=config)
train = torch.utils.data.TensorDataset(featuresTrain, targetTrain)
train_loader = make_loader(train, batch_size=config.batch_size)

My loader (generically) looks like this:

def make_loader(dataset, batch_size):
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         pin_memory=True,
                                         num_workers=8)
    return loader

Then I train like this:

for epoch in range(config.epochs):
    for _, (features, target) in enumerate(train_loader):
        loss = train_batch(features, target, model, optimizer, criterion)

And train_batch:

def train_batch(features, target, model, optimizer, criterion):
    features, target = features.to(device), target.to(device)

    # Forward pass ➡
    outputs = model(features)
    loss = criterion(outputs, target)
    return loss

1 Answer

Posted 2024-04-26 05:06:17

Here is a simple dataset that roughly implements what I believe you are after:

import torch
from torch.utils import data

class CustomDataset(data.Dataset):
    def __init__(self, featuresTrain, targetsTrain, tagsTrain, sample_equally=False):
        # self.tags should be a tensor in k-hot encoding form, so a 2D tensor
        self.tags = tagsTrain
        self.x = featuresTrain
        self.y = targetsTrain
        self.unique_tagsets = None
        self.sample_equally = sample_equally

        # self.active_tags is a 1D k-hot encoding vector
        self.active_tags = self.get_random_tag_set()

    def get_random_tag_set(self):
        # gets all unique sets of tags and returns one at random
        if self.unique_tagsets is None:
            self.unique_tagsets = self.tags.unique(dim=0)
        if self.sample_equally:
            # every unique tag set is equally likely
            rand_idx = torch.randint(len(self.unique_tagsets), (1,)).item()
            return self.unique_tagsets[rand_idx]
        else:
            # tag sets are sampled in proportion to how many data carry them
            rand_idx = torch.randint(len(self.tags), (1,)).item()
            return self.tags[rand_idx]

    def set_tags(self, tags):
        # specifies the set of tags a datum must have to be selected
        self.active_tags = tags

    def __getitem__(self, index):
        # get all indices of elements whose tag row matches self.active_tags
        indices = torch.where((self.tags == self.active_tags).all(dim=1))[0]

        # map the requested index onto the elements that have the tag set
        idx = indices[index % len(indices)]

        return self.x[idx], self.y[idx]

    def __len__(self):
        return len(self.y)

This dataset picks a tag set at random when it is constructed. Then, every time __getitem__() is called, it uses the given index to select among the data elements that carry the active tag set. After each mini-batch you can call set_tags() with a tag set of your choosing, or get_random_tag_set() followed by set_tags(), or at whatever frequency you want to change the tag set. The dataset inherits from torch.utils.data.Dataset, so you should be able to use it with torch.utils.data.DataLoader without modification.
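For instance, a minimal sketch of wiring this into the training loop from the question, drawing a fresh tag set after every mini-batch (config, model, optimizer, criterion, and train_batch are assumed from the question; keep num_workers=0 here, since worker processes would otherwise hold stale copies of the dataset):

tagsTrain = to_k_hot(tags)  # tags: list of per-example tag lists; to_k_hot() is defined below
dataset = CustomDataset(featuresTrain, targetTrain, tagsTrain)
loader = torch.utils.data.DataLoader(dataset, batch_size=config.batch_size)

for epoch in range(config.epochs):
    for features, target in loader:
        loss = train_batch(features, target, model, optimizer, criterion)
        # switch to a new randomly chosen tag set for the next mini-batch
        dataset.set_tags(dataset.get_random_tag_set())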

The sample_equally flag controls whether tag sets are sampled in proportion to how common they are in the data, or whether every unique tag set is sampled equally often regardless of how many elements carry it.

In short, this dataset is a bit rough around the edges, but it should let you sample batches that all share the same tag set. Its main drawback is that a given element may be sampled more than once within a batch.
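If the duplicate sampling is a concern, one alternative (not part of the original answer) is to keep the plain TensorDataset from the question and move the tag logic into a custom batch sampler, so that each batch is tag-homogeneous and no index repeats within an epoch. A sketch, assuming tagsTrain is the 2D k-hot tensor described below:

from collections import defaultdict
import torch

class TagSetBatchSampler(torch.utils.data.Sampler):
    # Yields lists of indices that all share the same tag row, so every
    # batch is tag-homogeneous and no index repeats within an epoch.
    def __init__(self, tags, batch_size):
        self.batch_size = batch_size
        groups = defaultdict(list)
        for i, row in enumerate(tags):
            groups[tuple(row.tolist())].append(i)
        self.groups = list(groups.values())

    def __iter__(self):
        batches = []
        for indices in self.groups:
            perm = torch.randperm(len(indices)).tolist()
            shuffled = [indices[p] for p in perm]
            for k in range(0, len(shuffled), self.batch_size):
                batches.append(shuffled[k:k + self.batch_size])
        # interleave tag sets by shuffling the batch order
        for j in torch.randperm(len(batches)).tolist():
            yield batches[j]

    def __len__(self):
        # number of batches per epoch (ceiling division per group)
        return sum(-(-len(g) // self.batch_size) for g in self.groups)

It plugs in via DataLoader's batch_sampler argument, e.g. torch.utils.data.DataLoader(train, batch_sampler=TagSetBatchSampler(tagsTrain, config.batch_size)), where train is the TensorDataset built in the question.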

As for the initial encoding: assume each data example starts with a list of its tags, so tags is a list of lists, where each sublist contains that example's tags. The following code converts this into a k-hot encoding:

def to_k_hot(tags):
    all_tags = []
    for ex in tags:
        for tag in ex:
            all_tags.append(tag)
    unique_tags = sorted(set(all_tags))  # remove duplicates; sort for a stable column order

    tagsTrain = torch.zeros([len(tags), len(unique_tags)])
    for i in range(len(tags)):            # index through all examples
        for j in range(len(unique_tags)): # index through all unique tags
            if unique_tags[j] in tags[i]:
                tagsTrain[i, j] = 1

    return tagsTrain

For example, say the dataset has the following tags:

tags = [ ['tag1'],
         ['tag1','tag2'],
         ['tag3'],
         ['tag2'],
         [],
         ['tag1','tag2','tag3'] ]

Calling to_k_hot(tags) then returns (columns ordered tag1, tag2, tag3):

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 0., 0.],
        [1., 1., 1.]])
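To see the selection mechanics on this toy example, hypothetical features and targets (just the row indices 0 through 5) can be passed through CustomDataset:

ds = CustomDataset(torch.arange(6.), torch.arange(6.), to_k_hot(tags))
ds.set_tags(torch.tensor([1., 1., 0.]))  # select only rows tagged tag1+tag2
print(ds[0])  # always row 1, the only example with exactly that tag set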
