PyTorch GlobalAttention: forward() missing 2 required positional arguments: 'edge_index' and 'batch'

Published 2024-04-23 06:54:48


While testing the GlobalAttention class, I get the error "forward() missing 2 required positional arguments: 'edge_index' and 'batch'":

import torch
from torch import nn
from torch.nn import Linear
from torch_geometric.nn import GlobalAttention

class GLOGCN(nn.Module):
    def __init__(self, gcn, hidden_channels):
        super(GLOGCN, self).__init__()
        self.gatt = GlobalAttention(gcn)
        self.lin = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, batch):
        w = self.gatt(x, batch)
        x = torch.mm(w, x)
        x = self.lin(x)
        return x

class GCN(torch.nn.Module):
    ...

model = GCN(..)
g_model = GLOGCN(model, hidden_channels=1000)

x = model(data.x, data.edge_index, data.batch)
g_out = g_model(x, data.batch)

Note that, according to the docs (https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.glob.GlobalAttention), GlobalAttention takes two arguments, (x, batch).
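In the docs, (x, batch) is the signature of GlobalAttention's forward(). The module handed to the constructor (here `gcn`) is its `gate_nn` argument, which the layer calls internally with `x` alone to produce one score per node. A GCN whose forward() also requires `edge_index` and `batch` therefore fails at that inner call. The following is a minimal pure-PyTorch sketch of the documented computation, not PyG's actual implementation; the function name `global_attention_pool` and the `nn.Linear` gate are hypothetical. It also shows that the result is one pooled vector per graph, not a weight matrix to be multiplied with `x`.

```python
import torch
from torch import nn


def global_attention_pool(x, batch, gate_nn, num_graphs=None):
    """Hypothetical sketch of attention pooling: per-graph softmax-weighted sum.

    x:       [N, F] node features
    batch:   [N] graph index of each node (0 .. num_graphs - 1)
    gate_nn: called with x ONLY; must return one score per node, shape [N, 1]
    """
    if num_graphs is None:
        num_graphs = int(batch.max()) + 1
    gate = gate_nn(x).view(-1)  # [N] raw scores; gate_nn never receives edge_index or batch

    # Softmax over the nodes of each graph separately (subtract the per-graph max for stability).
    gate_max = torch.full((num_graphs,), float("-inf"), dtype=gate.dtype, device=x.device)
    gate_max = gate_max.scatter_reduce(0, batch, gate, reduce="amax")
    gate = (gate - gate_max[batch]).exp()
    denom = torch.zeros(num_graphs, dtype=gate.dtype, device=x.device).index_add(0, batch, gate)
    gate = gate / denom[batch]  # [N], weights sum to 1 within each graph

    # Weighted sum of node features per graph -> one vector per graph.
    out = torch.zeros(num_graphs, x.size(1), dtype=x.dtype, device=x.device)
    return out.index_add(0, batch, gate.unsqueeze(-1) * x)  # [num_graphs, F]


torch.manual_seed(0)
hidden_channels = 16                     # hypothetical feature size
gate_nn = nn.Linear(hidden_channels, 1)  # hypothetical gate network: [N, F] -> [N, 1]
x = torch.randn(5, hidden_channels)      # 5 nodes in total
batch = torch.tensor([0, 0, 0, 1, 1])    # nodes 0-2 belong to graph 0, nodes 3-4 to graph 1
pooled = global_attention_pool(x, batch, gate_nn)
print(pooled.shape)  # torch.Size([2, 16])
```

Under this reading, the gate network only needs to map each node's feature vector to a scalar, so something like `nn.Linear(hidden_channels, 1)` fits where the full GCN does not, and the GCN output would be passed in as `x` instead.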

