在两个不同形状的torch张量之间寻找赢家单元
我正在尝试实现一个自组织映射(Self-Organizing Map),目的是根据输入样本与自组织映射之间的L2范数距离,选择出最匹配的单元(也就是胜出单元)。为了实现这个,我写了以下代码:
# Input batch: batch-size = 512, input-dim = 84-
z = torch.randn(512, 84)
# SOM shape: (height, width, input-dim)-
som = torch.randn(40, 40, 84)
# Compute L2 distance for a single sample out of 512 samples-
dist_l2 = np.linalg.norm((som.numpy() - z[0].numpy()), ord = 2, axis = 2)
# dist_l2.shape
# (40, 40)
# Get (row, column) index of the minimum of a 2d np array-
row, col = np.unravel_index(dist_l2.argmin(), dist_l2.shape)
print(f"BMU for z[0]; row = {row}, col = {col}")
# BMU for z[0]; row = 3, col = 9
对于第一个输入样本'z',在自组织映射中,胜出单元的索引是(3, 9)。我可以用一个循环来遍历所有512个输入样本,但这样做效率很低。
有没有什么高效的方法,能用PyTorch一次性计算整个批次的结果呢?
1 个回答
1
你可以通过扩展你的 som
张量,轻松地将这个操作应用到批量数据上:
_som = som.view(1,-1,z.size(-1)).expand(len(z),-1,-1)
# L2((512, 1600, 84), (512, 1, 84)) = (512, 1600, 1)
dist_l2 = torch.cdist(_som, z[:,None])[:,:,0]
# both are shaped (512,)
row, col = torch.unravel_index(dist_l2.argmin(1), (40,40))
注意:torch.unravel_index
从 PyTorch 版本 2.2 开始提供,如果你没有这个版本,可以参考这个用户制作的实现。