查找torch张量的邻域

0 投票
1 回答
37 浏览
提问于 2025-04-14 17:26

我正在尝试实现一个自组织映射(Self-Organizing Map),目的是根据输入样本来选择最佳匹配单元(也就是胜出单元)。这个选择是基于输入和自组织映射之间的L2距离。胜出单元(BMU)是指与给定输入(z)之间L2距离最小的那个单元,表示为“som[x, y]”。

# Input batch: batch-size = 512, input-dim = 84-
z = torch.randn(512, 84)

# SOM shape: (height, width, input-dim)-
som = torch.randn(40, 40, 84)

print(f"BMU row, col shapes; row = {row.shape} & col = {col.shape}")
# BMU row, col shapes; row = torch.Size([512]) & col = torch.Size([512])

为了更清楚地说明,对于批次中的第一个输入样本“z[0]”,胜出单元是“som[row[0], col[0]]”。

z[0].shape, som[row[0], col[0]].shape
# (torch.Size([84]), torch.Size([84]))

torch.norm((z[0] - som[row[0], col[0]]))表示z[0]和所有其他自组织映射单元之间的最小L2距离,除了row[0]和col[0]。

# Define initial neighborhood radius and learning rate-
neighb_rad = torch.tensor(2.0)
lr = 0.5

# To update weights for the first input "z[0]" and its corresponding BMU "som[row[0], col[0]]"-
for r in range(som.shape[0]):
    for c in range(som.shape[1]):
        neigh_dist = torch.exp(-torch.norm(input = (som[r, c] - som[row[0], col[0]])) / (2.0 * torch.pow(neighb_rad, 2)))
        som[r, c] = som[r, c] + (lr * neigh_dist * (z[0] - som[r, c]))

我该如何编写代码来:

  1. 更新每个胜出单元周围所有单元的权重,而不使用两个for循环(并且)
  2. 对所有输入“z”进行操作(这里,z有512个样本)

1 个回答

1

你可以把 som 的两个空间维度看作一个扁平的维度。根据你之前的问题,使用 dist_l2.argmin(1)(扁平化的索引)会比使用 rowcol(展开的索引)更好。我们来写一下中间的张量:

# expand batch-wise to (512, 1600, 84)
_som = som.view(1,-1,z.size(-1)).expand(len(z),-1,-1)

# expand z on dim=1 to match som
_z = z[:,None].expand(-1,40*40,-1)

# L2((512, 1600, 84), (512, 1, 84)) = (512, 1600, 1)
dist_l2 = torch.cdist(_som, z[:,None])[:,:,0]

# indices, shape reduced to (512,)
arg = dist_l2.argmin(1)

这样你就可以通过 _som[range(len(arg)),arg] 来获取所有的 som[row[i], col[i]]。在向量化的形式中,你通过引入一个额外的维度(大小为 40*40)来计算所有循环的迭代。此外,我们在 dim=1 上扩展了一个单例:

som_arg = _som[range(len(arg)),arg][:,None]

所以 som[r, c]som[row[i], col[i]] 之间的 L2 距离对应于 torch.cdist(_som, som_arg)

# shape (512, 1600, 84) x (512, 1, 84) -> (512, 1600, 1)
l2_dist = torch.cdist(_som, som_arg)

# shape (512, 1600, 1)
neigh_dist = torch.exp(-l2_dist) / (2.0 * torch.pow(neighb_rad, 2))

# shape (512, 1600, 84) accumulated batch-wise -> (1600, 84)
out = (lr * neigh_dist * (_z - _som)).sum(0)

最后,你可以将其重塑为所需的平方形状: out.view(40,40,-1)

撰写回答