Pytorch:Node2Vec:TypeError:元组索引必须是整数或切片,而不是元组

2024-06-16 10:20:30 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试从torch_geometric.nn库运行Node2Vec。作为参考,我下面是this示例

在运行train()函数时,我不断得到TypeError: tuple indices must be integers or slices, not tuple

我正在使用torch version 1.6.0CUDA 10.1以及torch-scattertorch-sparsetorch-clustertorch-spline-convtorch-geometric的最新版本

以下是详细的错误:

Part 1 of the Error

Part 2 of the Error

谢谢你的帮助


Tags: ofthe函数示例trainerrornntorch
1条回答
网友
1楼 · 发布于 2024-06-16 10:20:30

错误是由于torch.ops.torch_cluster.random_walk返回的是元组而不是数组/张量。我通过用这些函数替换torch_geometric.nn.Node2Vec中的函数pos_sampleneg_sample修复了它

def pos_sample(self, batch):
    batch = batch.repeat(self.walks_per_node)
    rowptr, col, _ = self.adj.csr()
    rw = random_walk(rowptr, col, batch, self.walk_length, self.p, self.q)
    if not isinstance(rw, torch.Tensor):
        rw = rw[0]

    walks = []
    num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size
    for j in range(num_walks_per_rw):
        walks.append(rw[:, j:j + self.context_size])
    return torch.cat(walks, dim=0)


def neg_sample(self, batch):
    batch = batch.repeat(self.walks_per_node * self.num_negative_samples)

    rw = torch.randint(self.adj.sparse_size(0),
                       (batch.size(0), self.walk_length))
    rw = torch.cat([batch.view(-1, 1), rw], dim=-1)

    walks = []
    num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size
    for j in range(num_walks_per_rw):
        walks.append(rw[:, j:j + self.context_size])
    return torch.cat(walks, dim=0)

请参阅PyTorch Node2Vecdocumentation

相关问题 更多 >