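How do I get values from a list of tensors by matching indices in PyTorch?
I have a question about how to extract values from a list of tensors using multiple indices.
I think there are similar questions here, such as this one, but I couldn't quite work out how to apply them.
I have a dataset of about 108,000 nodes and their links, with four-dimensional features.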
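import torch
# Assumed setup: the original post does not show how `device` was defined
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Four random (107940, 4) feature tensors
tmp = []
for _ in range(4):
    tmp.append(torch.rand((107940, 4), dtype=torch.float).to(device))
tmp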
# [tensor([[0.9249, 0.5367, 0.5161, 0.6898],
# [0.2189, 0.5593, 0.8087, 0.9893],
# [0.4344, 0.1507, 0.4631, 0.7680],
# ...,
# [0.7262, 0.0339, 0.9483, 0.2802],
# [0.8652, 0.3117, 0.8613, 0.6062],
# [0.5434, 0.9583, 0.3032, 0.3919]], device='cuda:0'),
# tensor([...], device='cuda:0'),
# tensor([...], device='cuda:0'),
# tensor([...], device='cuda:0')]
# batch.xxx: attributes of the batch sampled from the graph
# Note that batch.edge_index[0] is the target node and batch.edge_index[1] is the source node.
# For more information, see the PyTorch Geometric data format.
print(batch.n_id[batch.edge_index])
print(batch.edge_index_class)
#tensor([[10231, 3059, 32075, 10184, 1187, 6029, 10134, 10173, 6521, 9400,
# 14942, 31065, 10087, 10156, 10158, 26377, 85009, 918, 4542, 10176,
# 10180, 6334, 10245, 10228, 2339, 7891, 10214, 10240, 10041, 10020,
# 7610, 10324, 4320, 5951, 9078, 9709],
# [ 1624, 1624, 6466, 6466, 6779, 6779, 7691, 7691, 8655, 8655,
# 30347, 30347, 32962, 32962, 34435, 34435, 3059, 3059, 32075, 32075,
# 1187, 1187, 6029, 6029, 10173, 10173, 6521, 6521, 9400, 9400,
# 31065, 31065, 10087, 10087, 10158, 10158]], device='cuda:0')
#tensor([3., 3., 2., 2., 0., 0., 3., 3., 2., 2., 0., 0., 2., 2., 2., 2., 3., 3.,
# 2., 2., 0., 0., 0., 0., 3., 3., 2., 2., 2., 2., 0., 0., 2., 2., 2., 2.],
# device='cuda:0')
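In this case, I want a new tensor that contains the feature values matching edge_index_class.
For example, tmp_filled would take rows 1624, 10231, and 3059 from the fourth tensor in tmp, because those nodes are labeled 3 in edge_index_class.
Likewise, rows 6466, 32075, and 10184 from the third tensor in tmp would go into the same indices of tmp_filled.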
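To make the intended result concrete, here is a tiny CPU-only sketch with made-up data. The sizes, node IDs, and the names node_ids, edge_class, and expected are illustrative assumptions, not part of my real dataset. It fills rows edge by edge with explicit loops, which would be far too slow for the real graph but spells out the mapping I am after.

```python
import torch

# Toy stand-ins: tmp[c] is filled with the constant c + 10,
# so it is easy to see which tensor a row was copied from.
tmp = [torch.full((6, 2), float(c + 10)) for c in range(4)]

# Global node IDs of the edge endpoints, playing the role of batch.n_id[batch.edge_index]
node_ids = torch.tensor([[1, 2, 5],
                         [4, 4, 0]])
# One class per edge, playing the role of batch.edge_index_class
edge_class = torch.tensor([3., 3., 2.])

expected = torch.zeros(6, 2)
for e in range(node_ids.shape[1]):
    c = int(edge_class[e].item())
    for n in node_ids[:, e].tolist():
        # Both endpoints of edge e take their row from tmp[c]
        expected[n] = tmp[c][n]
```

In this toy case, rows 1, 2, and 4 come from tmp[3], rows 0 and 5 come from tmp[2], and row 3 stays zero because node 3 is not in any edge.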
To do this, I tried the code below:
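# Assumed initialization (not shown in the original post): zeros with the target shape
tmp_filled = torch.zeros_like(tmp[0])
for k in range(len(batch.edge_index_class)):
    tmp_filled[batch.n_id[torch.unique(batch.edge_index)]] = tmp[int(batch.edge_index_class[k].item())][batch.n_id[torch.unique(batch.edge_index)]]
tmp_filled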
# tensor([[0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.],
# ...,
# [0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.]], device='cuda:0')
But it returns the wrong result.
tmp_filled[1624]
# tensor([0.3438, 0.5555, 0.6229, 0.7983], device='cuda:0')
tmp[3][1624]
# tensor([0.6895, 0.3241, 0.1909, 0.1635], device='cuda:0')
How should I modify my code, given that I need tmp_filled to have shape (107940 x 4)?
Thank you for reading my question!
1 Answer
The code below does what I wanted.
If anyone has a more efficient solution, though, please feel free to share it.
for edge_index_class in torch.unique(batch.edge_index_class):
    # Find indices where edge_index_class matches
    indices = (batch.edge_index_class == edge_index_class).nonzero(as_tuple=True)[0]
    # Extract corresponding edge_index and n_id
    # edge_index = batch.edge_index[:, indices]
    n_id = torch.unique(batch.n_id[batch.edge_index[:, indices]])
    tmp_filled[n_id] = tmp[int(edge_index_class.item())][n_id]
tmp_filled[1624]
# tensor([0.6071, 0.9668, 0.9829, 0.1886], device='cuda:0')
tmp[3][1624]
# tensor([0.6071, 0.9668, 0.9829, 0.1886], device='cuda:0')
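One possible way to drop the Python loop is to stack the list into a single tensor and gather every row with one advanced-indexing call. This is only a sketch, not a tested replacement on the real data: the function name fill_by_edge_class and its parameter names are hypothetical, and it assumes each node only ever appears in edges of a single class. If a node shows up under several classes, the loop above keeps the highest class, while the write order here is not guaranteed.

```python
import torch

def fill_by_edge_class(tmp, node_ids, edge_class):
    """Hypothetical helper: copy rows from tmp[class] for both endpoints of every edge.

    tmp        : list of C tensors, each of shape (N, F)
    node_ids   : (2, E) long tensor of global node IDs, e.g. batch.n_id[batch.edge_index]
    edge_class : (E,) tensor of class labels, e.g. batch.edge_index_class
    """
    stacked = torch.stack(tmp)                      # (C, N, F)
    nodes = node_ids.reshape(-1)                    # (2E,): row 0 first, then row 1
    # Repeat the per-edge class once per endpoint row so it lines up with `nodes`
    classes = edge_class.long().repeat(node_ids.shape[0])
    out = torch.zeros_like(stacked[0])              # (N, F), zeros for untouched nodes
    # Row i of the right-hand side is stacked[classes[i], nodes[i]]
    out[nodes] = stacked[classes, nodes]
    return out
```

With the batch from the question, the call would presumably look like tmp_filled = fill_by_edge_class(tmp, batch.n_id[batch.edge_index], batch.edge_index_class). Note that torch.stack copies the four tensors into one block, so this trades extra memory for fewer kernel launches.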