在批次中广播张量矩阵乘法

0 投票
1 回答
27 浏览
提问于 2025-04-12 21:43

我该如何找到每个批次的响应和X数据的点积呢?

y_yhat_allBatches_matmulX_allBatches = torch.matmul(yTrue_yHat_allBatches_tensorSub, interceptXY_data_allBatches[:, :, :-1])

我希望y_yhat_allBatches_matmulX_allBatches的形状是2行5列。每一行对应一个特定的批次。

yTrue_yHat_allBatches_tensorSub.shape的形状是[2, 15],这里的行表示批次(1和2),列表示响应的大小(15)。

interceptXY_data_allBatches[:, :, :-1].shape = torch.Size([2, 15, 5])表示有2个批次,每个批次有15个观察值和5个特征。

请查看完整的可复现代码。

#define dataset
nFeatures_withIntercept = 5
NObservations = 15
miniBatches = 2
interceptXY_data_allBatches = torch.randn(miniBatches, NObservations, nFeatures_withIntercept+1) #+1 Y(response variable)

#random assign beta to work with
beta_holder = torch.rand(nFeatures_withIntercept)

#y_predicted for each mini-batch
y_predBatchAllBatches = torch.matmul(interceptXY_data_allBatches[:, :, :-1], beta_holder)

#y_true - y_predicted for each mini-batch
yTrue_yHat_allBatches_tensorSub = torch.sub(interceptXY_data_allBatches[..., -1], y_predBatchAllBatches)
y_yhat_allBatches_matmulX_allBatches = torch.matmul(yTrue_yHat_allBatches_tensorSub, interceptXY_data_allBatches[:, :, :-1])

1 个回答

1

看起来你有:

  • yTrue_yHat_allBatches_tensorSub 的形状是 (2, 15)
  • interceptXY_data_allBatches[:, :, :-1] 的形状是 (2, 15, 5)

如果你想把它们相乘,得到的形状是 (2, 5),那么你需要先把第一个变成 (2, 1, 15)。可以用 .unsqueeze(dim=1) 来实现。然后你可以使用 torch.bmm() 或者 @ 运算符来把 (2, 1, 15) 和 (2, 15, 5) 相乘,这样会得到一个形状为 (2, 1, 5) 的结果。最后,使用 .squeeze 来去掉多余的维度,最终得到 (2, 5)。

y_yhat_allBatches_matmulX_allBatches =\
    torch.bmm(yTrue_yHat_allBatches_tensorSub.unsqueeze(dim=1),
              interceptXY_data_allBatches[:, :, :-1]
             ).squeeze()

使用 @ 运算符的更简洁的写法:

y_yhat_allBatches_matmulX_allBatches =\
    (yTrue_yHat_allBatches_tensorSub.unsqueeze(dim=1) @ interceptXY_data_allBatches[:, :, :-1]).squeeze()

撰写回答