Weighted sampling in a DataLoader

Published 2024-04-26 13:05:27


# PyTorch's DataLoader, useful for large datasets
# it can also shuffle the data
trainDT = CustomDataset(X_train, y_train)
testDT = CustomDataset(X_test, y_test)
global trainloader


#Calculate weights for samples (more emphasis for more recent samples)
#Weightage of 2 for the last 1972 data instances
sample_weights = []
sample_weights = DataFrame(sample_weights)
#sample_weights = torch.tensor(sample_weights)

sample_weights[5000:6971] = 2
sample_weights[0:5000] = 1

#Assign weights to all samples
sample_weights_all = sample_weights[trainDT]


#Add in sample weights for index samples
weighted_sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights_all),
    replacement=False
)


trainloader = DataLoader(trainDT, batch_size=500, shuffle=True,sampler=weighted_sampler)
testloader = DataLoader(testDT, batch_size=500,shuffle=False)

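Hi, I am trying to give the last 1972 samples twice the weight of the others, as shown in the code above. The idea is to put more emphasis on the most recent samples (the last 1972). Running it raises the error below. Can anyone suggest how to get past it?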

Traceback (most recent call last):
  File "OptunaTest", line 615, in <module>
    n_instances, n_features, scores = run_analysis()
  File "OptunaTest", line 377, in run_analysis
    sample_weights_all = sample_weights[trainDT]
  File "/home/shar/.local/lib/python3.7/site-packages/pandas/core/frame.py", line 2995, in __getitem__
    indexer = self.columns.get_loc(key)
  File "/home/shar/.local/lib/python3.7/site-packages/pandas/core/indexes/base.py", line 2899, in get_loc
    return self._engine.get_loc(self._maybe_cast_indexer(key))
  File "pandas/_libs/index.pyx", line 107, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/index.pyx", line 131, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/hashtable_class_helper.pxi", line 1607, in pandas._libs.hashtable.PyObjectHashTable.get_item
  File "pandas/_libs/hashtable_class_helper.pxi", line 1614, in pandas._libs.hashtable.PyObjectHashTable.get_item
KeyError: <__main__.CustomDataset object at 0x7f0399e82c88>
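
The KeyError comes from `sample_weights[trainDT]`: indexing a DataFrame with the dataset object makes pandas look for a column whose label is that object, and no such column exists. `WeightedRandomSampler` only needs a flat sequence with one weight per training sample, so a possible fix is sketched below. The sketch rests on assumptions: `TensorDataset` with random data stands in for `CustomDataset`, the training set is taken to have 6971 rows (inferred from the slice bound in the question), and the 4 feature columns are made up. Note also that `sample_weights[5000:6971]` would cover 1971 rows rather than 1972, so the sketch slices from the end instead.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Stand-in for the question's CustomDataset (assumption for illustration):
# 6971 rows of dummy data with 4 made-up feature columns.
n_total, n_recent = 6971, 1972
X_train = torch.randn(n_total, 4)
y_train = torch.randn(n_total)
trainDT = TensorDataset(X_train, y_train)

# One weight per training sample: 1 by default, 2 for the last n_recent rows.
sample_weights = torch.ones(len(trainDT), dtype=torch.double)
sample_weights[-n_recent:] = 2.0

# replacement=True lets the weights influence every draw. Without replacement,
# drawing len(trainDT) samples would simply return each index exactly once.
weighted_sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)

# sampler and shuffle=True are mutually exclusive in DataLoader,
# so shuffle is left at its default here; the sampler already randomizes order.
trainloader = DataLoader(trainDT, batch_size=500, sampler=weighted_sampler)
```

With these weights, each recent sample is twice as likely to be drawn as an older one on every draw, so some rows will repeat within an epoch and others will be skipped. If every sample has to be seen exactly once per epoch, weighting the per-sample loss instead of the sampler may be a better fit.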
