How can I get values from a list of tensors by matching indices in PyTorch?

Asked 2025-04-14 17:47

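I have a question about how to extract values from a list of tensors using multiple indices.
There seem to be similar questions here, such as this one, but I still couldn't fully work out how to apply them.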

I have a dataset of about 108,000 nodes (107,940 exactly), with four-dimensional features for their links.

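import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Four feature tensors of shape (107940, 4); tmp[c] holds the features used for class c
tmp = [torch.rand((107940, 4), dtype=torch.float, device=device) for _ in range(4)]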

tmp
# [tensor([[0.9249, 0.5367, 0.5161, 0.6898],
#         [0.2189, 0.5593, 0.8087, 0.9893],
#         [0.4344, 0.1507, 0.4631, 0.7680],
#         ...,
#         [0.7262, 0.0339, 0.9483, 0.2802],
#         [0.8652, 0.3117, 0.8613, 0.6062],
#         [0.5434, 0.9583, 0.3032, 0.3919]], device='cuda:0'),
# tensor([...], device='cuda:0'),
# tensor([...], device='cuda:0'),
# tensor([...], device='cuda:0')]
# batch.xxx: attributes of the mini-batch sampled from the graph
# Note that batch.edge_index[0] is the target node and batch.edge_index[1] is the source node.
# For more information, see the PyTorch Geometric data format.

print(batch.n_id[batch.edge_index])
print(batch.edge_index_class)

#tensor([[10231,  3059, 32075, 10184,  1187,  6029, 10134, 10173,  6521,  9400,
#         14942, 31065, 10087, 10156, 10158, 26377, 85009,   918,  4542, 10176,
#         10180,  6334, 10245, 10228,  2339,  7891, 10214, 10240, 10041, 10020,
#          7610, 10324,  4320,  5951,  9078,  9709],
#        [ 1624,  1624,  6466,  6466,  6779,  6779,  7691,  7691,  8655,  8655,
#         30347, 30347, 32962, 32962, 34435, 34435,  3059,  3059, 32075, 32075,
#          1187,  1187,  6029,  6029, 10173, 10173,  6521,  6521,  9400,  9400,
#         31065, 31065, 10087, 10087, 10158, 10158]], device='cuda:0')
#tensor([3., 3., 2., 2., 0., 0., 3., 3., 2., 2., 0., 0., 2., 2., 2., 2., 3., 3.,
#        2., 2., 0., 0., 0., 0., 3., 3., 2., 2., 2., 2., 0., 0., 2., 2., 2., 2.],
#       device='cuda:0')
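For readers unfamiliar with the mini-batch layout: assuming `batch.n_id` holds the global node ID of each local node (as PyTorch Geometric's neighbor-sampling loaders provide) and `batch.edge_index` holds local indices, `batch.n_id[batch.edge_index]` is plain advanced indexing that replaces every local index with its global ID. The small tensors below are hypothetical and chosen only so that the result reproduces the first three columns of the printed output.

```python
import torch

# Hypothetical mini-batch: local node i corresponds to global node n_id[i]
n_id = torch.tensor([1624, 10231, 3059, 6466, 32075])

# Local edge_index (2 x E); per the question, row 0 is the target and row 1 the source
edge_index = torch.tensor([[1, 2, 4],
                           [0, 0, 3]])

# Advanced indexing maps each local index to its global ID and keeps the 2 x E shape
global_edge_index = n_id[edge_index]
print(global_edge_index.tolist())
# [[10231, 3059, 32075], [1624, 1624, 6466]]
```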

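In this case, I want a new tensor that holds, for each node, the feature values from the tensor matching its edge_index_class.
For example, tmp_filled should take rows 1624, 10231 and 3059 from the fourth tensor in tmp (tmp[3]), because their edges are labeled 3 in edge_index_class.
Likewise, rows 6466, 32075 and 10184 of the third tensor in tmp (tmp[2]) should be placed at the same indices of tmp_filled.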
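The matching rule above can be written out on the first four edge columns of the printed output: each edge's class applies to both of its endpoints. This sketch builds a plain node-to-class dictionary from those columns. It assumes every node appears with only one class (which holds for these four columns), so the overwrite order does not matter here.

```python
import torch

# First four columns of batch.n_id[batch.edge_index] and batch.edge_index_class from the question
global_edges = torch.tensor([[10231, 3059, 32075, 10184],
                             [ 1624, 1624,  6466,  6466]])
edge_class = torch.tensor([3., 3., 2., 2.])

node_to_class = {}
for col in range(global_edges.size(1)):
    cls = int(edge_class[col].item())            # classes are stored as floats, e.g. 3.
    for node in global_edges[:, col].tolist():   # both the target and the source node
        node_to_class[node] = cls

print(node_to_class)
# {10231: 3, 1624: 3, 3059: 3, 32075: 2, 6466: 2, 10184: 2}
```

So rows 10231, 1624 and 3059 should come from tmp[3], and rows 32075, 6466 and 10184 from tmp[2], exactly as described.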

To do this, I tried the following code:

for k in range(len(batch.edge_index_class)):
    tmp_filled[batch.n_id[torch.unique(batch.edge_index)]] = tmp[int(batch.edge_index_class[k].item())][batch.n_id[torch.unique(batch.edge_index)]]

tmp_filled
# tensor([[0., 0., 0., 0.],
#        [0., 0., 0., 0.],
#        [0., 0., 0., 0.],
#        ...,
#        [0., 0., 0., 0.],
#        [0., 0., 0., 0.],
#        [0., 0., 0., 0.]], device='cuda:0')

But it returns the wrong result:

tmp_filled[1624]
# tensor([0.3438, 0.5555, 0.6229, 0.7983], device='cuda:0')

tmp[3][1624]
# tensor([0.6895, 0.3241, 0.1909, 0.1635], device='cuda:0')
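A likely explanation for the mismatch: the index expression on both sides of the assignment does not depend on `k`, so every iteration overwrites all nodes of the batch with rows from the tensor of edge `k`'s class. After the loop, every node holds rows from the tensor of the last edge's class (class 2 in the printed data) rather than its own. The reproduction below uses small hypothetical stand-ins for `tmp` and `batch`.

```python
import torch

torch.manual_seed(0)
num_nodes = 6
tmp = [torch.rand((num_nodes, 4)) for _ in range(4)]   # stand-in feature tables
nodes = torch.tensor([0, 1, 2, 3])                     # stand-in for batch.n_id[torch.unique(batch.edge_index)]
edge_class = torch.tensor([3., 3., 2., 2.])            # the last edge has class 2

tmp_filled = torch.zeros((num_nodes, 4))
for k in range(len(edge_class)):
    # Same pattern as the question: the indices never change with k
    tmp_filled[nodes] = tmp[int(edge_class[k].item())][nodes]

# Every node ends up with rows from tmp[2], the class of the last edge
print(torch.equal(tmp_filled[nodes], tmp[2][nodes]))   # True
```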

How should I modify my code, given that tmp_filled needs to have the shape (107940 x 4)?
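On the shape requirement: one way to get a (107940 x 4) output that also matches `tmp` in dtype and device is `torch.zeros_like(tmp[0])`. Rows for nodes that appear in no edge of the batch then simply stay zero. The device selection line is an assumption added so the sketch also runs without a GPU.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tmp = [torch.rand((107940, 4), dtype=torch.float, device=device) for _ in range(4)]

# Same shape, dtype and device as each tensor in tmp; rows that are never written stay zero
tmp_filled = torch.zeros_like(tmp[0])
print(tmp_filled.shape, tmp_filled.dtype)   # torch.Size([107940, 4]) torch.float32
```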

Thanks for reading my question!

1 Answer


The code below does what I wanted.
If anyone has a more efficient solution, though, feel free to share it.

for edge_index_class in torch.unique(batch.edge_index_class):
    # Positions of the edges that belong to the current class
    indices = (batch.edge_index_class == edge_index_class).nonzero(as_tuple=True)[0]

    # Global IDs of all nodes (targets and sources) touched by those edges
    n_id = torch.unique(batch.n_id[batch.edge_index[:, indices]])

    # Copy those rows from the feature tensor that matches the class
    tmp_filled[n_id] = tmp[int(edge_index_class.item())][n_id]
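One property of this loop worth knowing: `torch.unique` returns sorted values by default, so the classes are processed in ascending order, and a node that touches edges of several classes keeps the row from the highest class. The toy example below wraps the answer's loop in a function (the name `fill_by_class` is hypothetical) and uses made-up tensors where node 1 touches both a class-0 edge and a class-3 edge.

```python
import torch

def fill_by_class(tmp, n_id, edge_index, edge_index_class):
    """The answer's loop, wrapped in a function (hypothetical name)."""
    tmp_filled = torch.zeros_like(tmp[0])
    for cls in torch.unique(edge_index_class):                 # sorted ascending: 0., 3.
        indices = (edge_index_class == cls).nonzero(as_tuple=True)[0]
        nodes = torch.unique(n_id[edge_index[:, indices]])    # global IDs touched by this class
        tmp_filled[nodes] = tmp[int(cls.item())][nodes]
    return tmp_filled

torch.manual_seed(0)
tmp = [torch.rand((5, 4)) for _ in range(4)]
n_id = torch.arange(5)                   # local == global IDs in this toy batch
edge_index = torch.tensor([[0, 2],
                           [1, 1]])      # node 1 appears in both edges
edge_index_class = torch.tensor([0., 3.])

out = fill_by_class(tmp, n_id, edge_index, edge_index_class)
print(torch.equal(out[1], tmp[3][1]))   # True: the higher class (3) wins for node 1
print(torch.equal(out[0], tmp[0][0]))   # True: node 0 only touches the class-0 edge
```

If a different tie-breaking rule is wanted (for example, the class of the first edge), the loop order is the thing to change.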


tmp_filled[1624]
# tensor([0.6071, 0.9668, 0.9829, 0.1886], device='cuda:0')

tmp[3][1624]
# tensor([0.6071, 0.9668, 0.9829, 0.1886], device='cuda:0')
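In reply to the request for something more efficient, here is a possible loop-free sketch (not from the original answer). It first gives each node the maximum class among its incident edges with `Tensor.scatter_reduce_(..., reduce="amax")`, which reproduces the highest-class-wins behavior of the loop above, and then gathers all rows at once from `torch.stack(tmp)`. It assumes PyTorch 1.12 or later (for `scatter_reduce_`) and that every tensor in `tmp` has one row per global node ID; the function name `fill_by_class_vectorized` is hypothetical.

```python
import torch

def fill_by_class_vectorized(tmp, n_id, edge_index, edge_index_class):
    """Loop-free variant (hypothetical name); a node touching several classes takes the highest one."""
    num_nodes = tmp[0].size(0)
    global_ids = n_id[edge_index].reshape(-1)          # row 0 then row 1, length 2E
    edge_cls = edge_index_class.long().repeat(2)       # each edge's class, once per endpoint

    # Per-node class: -1 means "not in this batch", otherwise the max class of its edges
    node_cls = torch.full((num_nodes,), -1, dtype=torch.long, device=global_ids.device)
    node_cls.scatter_reduce_(0, global_ids, edge_cls, reduce="amax")

    nodes = (node_cls >= 0).nonzero(as_tuple=True)[0]
    stacked = torch.stack(tmp)                         # (num_classes, num_nodes, num_features)
    tmp_filled = torch.zeros_like(tmp[0])
    tmp_filled[nodes] = stacked[node_cls[nodes], nodes]  # row `node` from the table of its class
    return tmp_filled
```

With the objects from the question this would be called as `tmp_filled = fill_by_class_vectorized(tmp, batch.n_id, batch.edge_index, batch.edge_index_class)`. Stacking `tmp` costs one extra copy of the four feature tensors (about 7 MB at 107,940 x 4 float32 each), in exchange for replacing the Python loop with a fixed number of kernel launches.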
