PyTorch: building training batches drawn from two datasets

I need to build training batches for my network in which N/2 samples come from one dataset and N/2 from another. I considered using the batch_sampler argument of DataLoader and wrote this class:

import random
import torch
from torch.utils.data import Sampler

def chunk(indices, chunk_size):
    # Split a list of indices into tensors of length chunk_size.
    return torch.split(torch.tensor(indices), chunk_size)

class MixedBatch(Sampler):
    def __init__(self, dataset, batch_size, split_index):
        # Indices [0, split_index) belong to the first dataset,
        # [split_index, len(dataset)) to the second.
        self.first_half_indices = list(range(split_index))
        self.second_half_indices = list(range(split_index, len(dataset)))
        self.batch_size = batch_size

    def __iter__(self):
        random.shuffle(self.first_half_indices)
        random.shuffle(self.second_half_indices)
        # Chunk each half into pieces of batch_size // 2.
        first_half_batches = chunk(self.first_half_indices, self.batch_size // 2)
        second_half_batches = chunk(self.second_half_indices, self.batch_size // 2)

        # Pair one chunk from each half so every batch is half-and-half.
        mixed = []
        for i, j in zip(first_half_batches, second_half_batches):
            mixed.append(torch.cat((i, j)))
        mixed = torch.cat(mixed)

        combined = list(torch.split(mixed, self.batch_size))
        combined = [batch.tolist() for batch in combined]
        random.shuffle(combined)
        return iter(combined)

    def __len__(self):
        return (len(self.first_half_indices) + len(self.second_half_indices)) // self.batch_size
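
For reference, here is a minimal usage sketch of how such a sampler would typically be attached to a DataLoader. It assumes the two datasets are merged with torch.utils.data.ConcatDataset and that split_index is the length of the first dataset; dataset_a, dataset_b, and the batch size of 32 are placeholder names, not from the question.

from torch.utils.data import DataLoader, ConcatDataset

# dataset_a / dataset_b are placeholders for the two source datasets.
combined = ConcatDataset([dataset_a, dataset_b])
sampler = MixedBatch(combined, batch_size=32, split_index=len(dataset_a))

# When batch_sampler is given, DataLoader must not also receive
# batch_size, shuffle, sampler, or drop_last.
loader = DataLoader(combined, batch_sampler=sampler)

for batch in loader:
    pass  # each batch holds 16 samples from dataset_a and 16 from dataset_b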

But it raises the error: AttributeError: 'SequentialSampler' object has no attribute 'set_epoch'. Is there another way to do this, perhaps one that works with sampler=DistributedSampler?
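
One possible workaround, offered as an assumption rather than something stated in the question: the error suggests the training loop calls sampler.set_epoch(epoch) each epoch, as it would on a DistributedSampler, so giving the custom sampler a compatible set_epoch hook may be enough. A hypothetical sketch:

class MixedBatchWithSetEpoch(MixedBatch):
    # Hypothetical variant: expose the set_epoch hook that distributed
    # training loops call on their sampler at the start of every epoch.
    def set_epoch(self, epoch):
        # Re-seed the shuffles so every process produces the same batch order.
        random.seed(epoch)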

Thanks a lot.

