理解PyTorch模型中的批处理

0 投票
1 回答
33 浏览
提问于 2025-04-12 04:26

我有一个模型,它是我整体模型流程中的一个步骤:

import torch
import torch.nn as nn

class NPB(nn.Module):
    def __init__(self, d, nhead, num_layers, dropout=0.1):
        super(NPB, self).__init__()
            
        self.te = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d, nhead=nhead, dropout=dropout, batch_first=True),
            num_layers=num_layers,
        ) 

        self.t_emb = nn.Parameter(torch.randn(1, d))
        
        self.L = nn.Parameter(torch.randn(1, d)) 

        self.td = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=d, nhead=nhead, dropout=dropout, batch_first=True),
            num_layers=num_layers,
        ) 

        self.ffn = nn.Linear(d, 6)
    
    def forward(self, t_v, t_i):
        print("--------------- t_v, t_i -----------------")
        print('t_v: ', tuple(t_v.shape))
        print('t_i: ', tuple(t_i.shape))

        print("--------------- t_v + t_i + t_emb -----------------")
        _x = t_v + t_i + self.t_emb
        print(tuple(_x.shape))

        print("--------------- te ---------------")
        _x = self.te(_x)
        print(tuple(_x.shape))
        
        print("--------------- td ---------------")
        _x = self.td(self.L, _x)
        print(tuple(_x.shape))

        print("--------------- ffn ---------------")
        _x = self.ffn(_x)
        print(tuple(_x.shape))

        return _x

这里的 t_vt_i 是来自之前编码器模块的输入。我把它们的形状设置为 (4,256),其中 256 是特征的数量,4 是批处理的大小。t_emb 是时间嵌入。L 是一个学习到的矩阵,表示查询的嵌入。我用以下代码测试了这个模块:

t_v = torch.randn((4,256))
t_i = torch.randn((4,256))
npb = NPB(d=256, nhead=8, num_layers=2)
npb(t_v, t_i)

它输出了:

=============== NPB ===============
--------------- t_v, t_i -----------------
t_v:  (4, 256)
t_i:  (4, 256)
--------------- t_v + t_i + t_emb -----------------
(4, 256)
--------------- te ---------------
(4, 256)
--------------- td ---------------
(1, 256)
--------------- ffn ---------------
(1, 6)

我原本期待输出的形状应该是 (4,6),也就是每个样本有6个值,批处理的大小是 4。但是输出的结果是 (1,6)。经过多次调整,我尝试把 t_embL 的形状从 (1,d) 改为 (4,d),因为我不想让所有样本共享这些变量(通过广播机制:

self.t_emb = nn.Parameter(torch.randn(4, d)) # [n, d] = [4, 256]     
self.L = nn.Parameter(torch.randn(4, d)) 

这样就得到了期望的输出形状 (4,6)

--------------- t_v, t_i -----------------
t_v:  (4, 256)
t_i:  (4, 256)
--------------- t_v + t_i + t_emb -----------------
(4, 256)
--------------- te ---------------
(4, 256)
--------------- td ---------------
(4, 256)
--------------- ffn ---------------
(4, 6)

我有以下疑问:

问题1. 为什么把 Lt_emb 的形状从 (1,d) 改为 (4,d) 就有效了?为什么用 (1,d) 通过广播就不行?
问题2. 我这样做批处理的方式对吗?还是说输出看起来是正确的,但实际上在内部做的事情和我预期的(为每个批次大小为4的样本预测6个值)不一样?

1 个回答

0

查看文档 - transformer类transformer解码器

对于一个没有批处理的(二维)输入,其中 src = (S, E)tgt = (T, E),输出的形状将是 (T, E)

在transformer解码器层中,第一个参数是 tgt,它决定了输出的大小。

因为你把 tgt 参数 L 定义为 torch.randn(1, d),所以你的transformer解码器输出的大小将是 (1, d)

这和广播没有关系,这只是transformer层的输入输出机制。

撰写回答