理解PyTorch模型中的批处理
我有一个模型,它是我整体模型流程中的一个步骤:
import torch
import torch.nn as nn
class NPB(nn.Module):
def __init__(self, d, nhead, num_layers, dropout=0.1):
super(NPB, self).__init__()
self.te = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=d, nhead=nhead, dropout=dropout, batch_first=True),
num_layers=num_layers,
)
self.t_emb = nn.Parameter(torch.randn(1, d))
self.L = nn.Parameter(torch.randn(1, d))
self.td = nn.TransformerDecoder(
nn.TransformerDecoderLayer(d_model=d, nhead=nhead, dropout=dropout, batch_first=True),
num_layers=num_layers,
)
self.ffn = nn.Linear(d, 6)
def forward(self, t_v, t_i):
print("--------------- t_v, t_i -----------------")
print('t_v: ', tuple(t_v.shape))
print('t_i: ', tuple(t_i.shape))
print("--------------- t_v + t_i + t_emb -----------------")
_x = t_v + t_i + self.t_emb
print(tuple(_x.shape))
print("--------------- te ---------------")
_x = self.te(_x)
print(tuple(_x.shape))
print("--------------- td ---------------")
_x = self.td(self.L, _x)
print(tuple(_x.shape))
print("--------------- ffn ---------------")
_x = self.ffn(_x)
print(tuple(_x.shape))
return _x
这里的 t_v
和 t_i
是来自之前编码器模块的输入。我把它们的形状设置为 (4,256)
,其中 256
是特征的数量,4
是批处理的大小。t_emb
是时间嵌入。L
是一个学习到的矩阵,表示查询的嵌入。我用以下代码测试了这个模块:
t_v = torch.randn((4,256))
t_i = torch.randn((4,256))
npb = NPB(d=256, nhead=8, num_layers=2)
npb(t_v, t_i)
它输出了:
=============== NPB ===============
--------------- t_v, t_i -----------------
t_v: (4, 256)
t_i: (4, 256)
--------------- t_v + t_i + t_emb -----------------
(4, 256)
--------------- te ---------------
(4, 256)
--------------- td ---------------
(1, 256)
--------------- ffn ---------------
(1, 6)
我原本期待输出的形状应该是 (4,6)
,也就是每个样本有6个值,批处理的大小是 4
。但是输出的结果是 (1,6)
。经过多次调整,我尝试把 t_emb
和 L
的形状从 (1,d)
改为 (4,d)
,因为我不想让所有样本共享这些变量(通过广播机制:
self.t_emb = nn.Parameter(torch.randn(4, d)) # [n, d] = [4, 256]
self.L = nn.Parameter(torch.randn(4, d))
这样就得到了期望的输出形状 (4,6)
:
--------------- t_v, t_i -----------------
t_v: (4, 256)
t_i: (4, 256)
--------------- t_v + t_i + t_emb -----------------
(4, 256)
--------------- te ---------------
(4, 256)
--------------- td ---------------
(4, 256)
--------------- ffn ---------------
(4, 6)
我有以下疑问:
问题1. 为什么把 L
和 t_emb
的形状从 (1,d)
改为 (4,d)
就有效了?为什么用 (1,d)
通过广播就不行?
问题2. 我这样做批处理的方式对吗?还是说输出看起来是正确的,但实际上在内部做的事情和我预期的(为每个批次大小为4的样本预测6个值)不一样?
1 个回答
0
查看文档 - transformer类,transformer解码器
对于一个没有批处理的(二维)输入,其中 src = (S, E)
和 tgt = (T, E)
,输出的形状将是 (T, E)
。
在transformer解码器层中,第一个参数是 tgt
,它决定了输出的大小。
因为你把 tgt
参数 L
定义为 torch.randn(1, d)
,所以你的transformer解码器输出的大小将是 (1, d)
。
这和广播没有关系,这只是transformer层的输入输出机制。