import torch
from torch import nn
class TwoInputsNet(nn.Module):
def __init__(self):
super(TwoInputsNet, self).__init__()
self.conv = nn.Conv2d( ... ) # set up your layer here
self.fc1 = nn.Linear( ... ) # set up first FC layer
self.fc2 = nn.Linear( ... ) # set up the other FC layer
def forward(self, input1, input2):
c = self.conv(input1)
f = self.fc1(input2)
# now we can reshape `c` and `f` to 2D and concat them
combined = torch.cat((c.view(c.size(0), -1),
f.view(f.size(0), -1)), dim=1)
out = self.fc2(combined)
return out
我认为“组合它们”是指concatenate这两个输入。
假设你沿着第二个维度集中:
注意,当您定义},以及{}的输出空间维度。在
self.fc2
的输入数量时,您需要同时考虑self.conv
的{相关问题 更多 >
编程相关推荐