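So I need to build the training set for my network so that each batch of N samples contains N/2 samples from the first dataset and N/2 from the other dataset. I thought about using the `batch_sampler` argument of DataLoader and wrote this class: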
```python
import random

import torch
from torch.utils.data import Sampler


def chunk(indices, chunk_size):
    # Split a list of indices into tensors of at most chunk_size elements each
    return torch.split(torch.tensor(indices), chunk_size)


class MixedBatch(Sampler):
    # Yields batches whose first half is drawn from indices [0, split_index)
    # (first dataset) and whose second half from [split_index, len(dataset))
    def __init__(self, dataset, batch_size, split_index):
        self.first_half_indices = list(range(split_index))
        self.second_half_indices = list(range(split_index, len(dataset)))
        self.batch_size = batch_size

    def __iter__(self):
        random.shuffle(self.first_half_indices)
        random.shuffle(self.second_half_indices)
        half = self.batch_size // 2
        first_half_batches = chunk(self.first_half_indices, half)
        second_half_batches = chunk(self.second_half_indices, half)
        # Pair one half-batch from each dataset; zip stops at the shorter one
        mixed = [torch.cat((i, j)).tolist()
                 for i, j in zip(first_half_batches, second_half_batches)]
        random.shuffle(mixed)
        return iter(mixed)

    def __len__(self):
        # Must equal the number of batches __iter__ yields: ceil(n / half)
        # half-batches per dataset, limited by the smaller dataset
        half = self.batch_size // 2
        n_first = (len(self.first_half_indices) + half - 1) // half
        n_second = (len(self.second_half_indices) + half - 1) // half
        return min(n_first, n_second)
```
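A minimal sketch of why the error likely appears, assuming the training loop (or a framework it uses) calls `loader.sampler.set_epoch(epoch)`, as is customary with DistributedSampler: when only `batch_sampler` is passed, DataLoader still creates a default `SequentialSampler` and exposes it as `loader.sampler`, while the custom object is only reachable as `loader.batch_sampler`. The toy datasets and the fixed batch list below are made up for illustration.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, SequentialSampler, TensorDataset

# Two toy datasets standing in for the real ones (hypothetical data)
first = TensorDataset(torch.zeros(6, 1))
second = TensorDataset(torch.ones(6, 1))
combined = ConcatDataset([first, second])
split_index = len(first)  # indices [0, 6) -> first, [6, 12) -> second

# Any iterable of index lists can serve as batch_sampler; here a fixed one
batches = [[0, 1, 6, 7], [2, 3, 8, 9], [4, 5, 10, 11]]
loader = DataLoader(combined, batch_sampler=batches)

# DataLoader still builds a default sampler, and that is what loader.sampler is
print(type(loader.sampler).__name__)         # SequentialSampler
print(hasattr(loader.sampler, "set_epoch"))  # False

for (x,) in loader:
    print(x.squeeze(1).tolist())  # two samples of zeros, then two of ones
```

So a `set_epoch` call aimed at `loader.sampler` never reaches the custom batch sampler.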
But it raises the error `AttributeError: 'SequentialSampler' object has no attribute 'set_epoch'`. Is there another way to do this that works with `sampler=DistributedSampler`?
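One possible workaround, sketched under assumptions rather than offered as a definitive fix: DataLoader treats `batch_sampler` as mutually exclusive with `sampler`, so a DistributedSampler cannot be stacked on top of `MixedBatch`. Instead, the batch sampler itself can split batches across ranks and expose `set_epoch`. The class `DistributedMixedBatchSampler` is hypothetical; it assumes the two datasets are concatenated with the first occupying indices `[0, split_index)` (as in `MixedBatch`), that all ranks use the same `seed`, that `batch_size >= 2`, and that incomplete batches may be dropped so every rank gets the same number of batches.

```python
import torch
import torch.distributed as dist
from torch.utils.data import Sampler


class DistributedMixedBatchSampler(Sampler):
    """Hypothetical batch sampler: each batch holds batch_size // 2 indices from
    [0, split_index) and batch_size // 2 from [split_index, dataset_len), and
    the batches are divided among distributed ranks."""

    def __init__(self, dataset_len, batch_size, split_index,
                 num_replicas=None, rank=None, seed=0):
        if num_replicas is None or rank is None:
            initialized = dist.is_available() and dist.is_initialized()
            num_replicas = dist.get_world_size() if initialized else 1
            rank = dist.get_rank() if initialized else 0
        self.half = batch_size // 2  # assumes batch_size >= 2
        self.split_index = split_index
        self.dataset_len = dataset_len
        self.num_replicas = num_replicas
        self.rank = rank
        self.seed = seed
        self.epoch = 0
        # Only full batches are kept, limited by the smaller dataset, and the
        # count is rounded down so every rank yields the same number of batches
        n_pairs = min(split_index, dataset_len - split_index) // self.half
        self.batches_per_rank = n_pairs // num_replicas

    def set_epoch(self, epoch):
        # Same role as DistributedSampler.set_epoch: new shuffle each epoch
        self.epoch = epoch

    def __iter__(self):
        # Every rank seeds identically, so all agree on the global batch list
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        first = torch.randperm(self.split_index, generator=g)
        second = torch.randperm(self.dataset_len - self.split_index,
                                generator=g) + self.split_index
        total = self.batches_per_rank * self.num_replicas
        h = self.half
        batches = [torch.cat((first[i * h:(i + 1) * h],
                              second[i * h:(i + 1) * h])).tolist()
                   for i in range(total)]
        # Shuffle the batch order, then give each rank every num_replicas-th batch
        order = torch.randperm(total, generator=g).tolist()
        return iter([batches[k] for k in order[self.rank::self.num_replicas]])

    def __len__(self):
        return self.batches_per_rank
```

With it, the loader would be built as `DataLoader(dataset, batch_sampler=DistributedMixedBatchSampler(len(dataset), batch_size, split_index))`, and the epoch loop would call `loader.batch_sampler.set_epoch(epoch)` instead of `loader.sampler.set_epoch(epoch)`. Seeding a `torch.Generator` with `seed + epoch` mirrors what DistributedSampler does, so the ranks build the same global list and each takes a disjoint slice; note that `batch_size` here is per rank.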
Thanks a lot!