如何在PyTorch中平衡不平衡的数据(使用WeightedRandomSampler)?

2024-04-18 18:16:47 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个2类问题,我的数据高度不平衡。我有232550个样本来自一个类,而{}来自第二个类。Pythorch文档和internet告诉我为我的DataLoader使用类WeightedRandomSampler。在

我试过用WeightedRandomSampler,但我总是出错。在

    trainratio = np.bincount(trainset.labels)
    classcount = trainratio.tolist()
    train_weights = 1./torch.tensor(classcount, dtype=torch.float)
    train_sampleweights = train_weights[trainset.labels]
    train_sampler = WeightedRandomSampler(weights=train_sampleweights, 
    num_samples = len(train_sampleweights))
    trainloader = DataLoader(trainset, sampler=train_sampler, 
    shuffle=False)

我不明白为什么在初始化WeightedRandomSampler类时会出现此错误?在

我尝试过其他类似的解决方法,但到目前为止,所有的尝试都会产生一些错误。 我应该如何实现这一点来平衡我的训练、验证和测试数据?在

当前收到此错误:

train__sampleweights = train_weights[trainset.labels] ValueError: too many dimensions 'str'


Tags: 数据labels高度错误traintorch样本sampler