在两个不同形状的torch张量之间寻找赢家单元

0 投票
1 回答
26 浏览
提问于 2025-04-14 17:28

我正在尝试实现一个自组织映射(Self-Organizing Map),目的是根据输入样本与自组织映射之间的L2范数距离,选择出最匹配的单元(也就是胜出单元)。为了实现这个,我写了以下代码:

# Input batch: batch-size = 512, input-dim = 84-
z = torch.randn(512, 84)

# SOM shape: (height, width, input-dim)-
som = torch.randn(40, 40, 84)

# Compute L2 distance for a single sample out of 512 samples-
dist_l2 = np.linalg.norm((som.numpy() - z[0].numpy()), ord = 2, axis = 2)

# dist_l2.shape
# (40, 40)

# Get (row, column) index of the minimum of a 2d np array-
row, col = np.unravel_index(dist_l2.argmin(), dist_l2.shape)

print(f"BMU for z[0]; row = {row}, col  = {col}")
# BMU for z[0]; row = 3, col  = 9

对于第一个输入样本'z',在自组织映射中,胜出单元的索引是(3, 9)。我可以用一个循环来遍历所有512个输入样本,但这样做效率很低。

有没有什么高效的方法,能用PyTorch一次性计算整个批次的结果呢?

1 个回答

1

你可以通过扩展你的 som 张量,轻松地将这个操作应用到批量数据上:

_som = som.view(1,-1,z.size(-1)).expand(len(z),-1,-1)

# L2((512, 1600, 84), (512, 1, 84)) = (512, 1600, 1)
dist_l2 = torch.cdist(_som, z[:,None])[:,:,0]

# both are shaped (512,)
row, col = torch.unravel_index(dist_l2.argmin(1), (40,40))

注意:torch.unravel_index 从 PyTorch 版本 2.2 开始提供,如果你没有这个版本,可以参考这个用户制作的实现。

撰写回答