Testing the GlobalAttention class raises "forward() missing 2 required positional arguments: 'edge_index' and 'batch'"
class GLOGCN(nn.Module):
    def __init__(self, gcn, hidden_channels):
        super(GLOGCN, self).__init__()
        self.gatt = GlobalAttention(gcn)
        self.lin = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, batch):
        w = self.gatt(x, batch)
        x = torch.mm(w, x)
        x = self.lin(x)
        return x
class GCN(torch.nn.Module):
    ...

model = GCN(...)
g_model = GLOGCN(model, hidden_channels=1000)
x = model(data.x, data.edge_index, data.batch)
g_out = g_model(x, data.batch)
Note that GlobalAttention takes two arguments, (x, batch), according to the docs: https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.glob.GlobalAttention
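To make the signature mismatch concrete: GlobalAttention calls its gate network on node features only, so a gate_nn whose forward also demands edge_index (like a full GCN) will fail at that call. Below is a minimal pure-Python sketch (no torch_geometric, hypothetical function name `global_attention_pool`) of the pooling that GlobalAttention performs: score each node with gate(x), softmax the scores within each graph given by `batch`, then weight-sum the node features per graph.

```python
import math

def global_attention_pool(x, batch, gate):
    # x: list of node feature vectors (one list per node)
    # batch: graph index of each node (like PyG's `batch` tensor)
    # gate: callable mapping ONE feature vector -> scalar score;
    #       note it receives only node features, never edge_index
    scores = [gate(v) for v in x]
    num_graphs = max(batch) + 1

    # softmax of the gate scores within each graph
    exp_scores = [math.exp(s) for s in scores]
    denom = [0.0] * num_graphs
    for e, b in zip(exp_scores, batch):
        denom[b] += e

    # weighted sum of node features per graph
    out = [[0.0] * len(x[0]) for _ in range(num_graphs)]
    for v, e, b in zip(x, exp_scores, batch):
        w = e / denom[b]
        for j, vj in enumerate(v):
            out[b][j] += w * vj
    return out

# two graphs: nodes 0 and 1 belong to graph 0, node 2 to graph 1
x = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
batch = [0, 0, 1]
pooled = global_attention_pool(x, batch, gate=lambda v: sum(v))
# graph 0: equal scores -> weights 0.5/0.5 -> [0.5, 0.5]
# graph 1: single node  -> weight 1.0     -> [2.0, 2.0]
```

This suggests the fix for the code above: pass GlobalAttention a gate network that takes only x (e.g. a plain Linear/MLP mapping hidden_channels to 1), and run the GCN separately to produce x before pooling.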