查找torch张量的邻域
我正在尝试实现一个自组织映射(Self-Organizing Map),目的是根据输入样本来选择最佳匹配单元(也就是胜出单元)。这个选择是基于输入和自组织映射之间的L2距离。胜出单元(BMU)是指与给定输入(z)之间L2距离最小的那个单元,表示为“som[x, y]”。
# Input batch: batch-size = 512, input-dim = 84-
z = torch.randn(512, 84)
# SOM shape: (height, width, input-dim)-
som = torch.randn(40, 40, 84)
print(f"BMU row, col shapes; row = {row.shape} & col = {col.shape}")
# BMU row, col shapes; row = torch.Size([512]) & col = torch.Size([512])
为了更清楚地说明,对于批次中的第一个输入样本“z[0]”,胜出单元是“som[row[0], col[0]]”。
z[0].shape, som[row[0], col[0]].shape
# (torch.Size([84]), torch.Size([84]))
torch.norm((z[0] - som[row[0], col[0]]))
表示z[0]和所有其他自组织映射单元之间的最小L2距离,除了row[0]和col[0]。
# Define initial neighborhood radius and learning rate-
neighb_rad = torch.tensor(2.0)
lr = 0.5
# To update weights for the first input "z[0]" and its corresponding BMU "som[row[0], col[0]]"-
for r in range(som.shape[0]):
for c in range(som.shape[1]):
neigh_dist = torch.exp(-torch.norm(input = (som[r, c] - som[row[0], col[0]])) / (2.0 * torch.pow(neighb_rad, 2)))
som[r, c] = som[r, c] + (lr * neigh_dist * (z[0] - som[r, c]))
我该如何编写代码来:
- 更新每个胜出单元周围所有单元的权重,而不使用两个for循环(并且)
- 对所有输入“z”进行操作(这里,z有512个样本)
1 个回答
1
你可以把 som
的两个空间维度看作一个扁平的维度。根据你之前的问题,使用 dist_l2.argmin(1)
(扁平化的索引)会比使用 row
和 col
(展开的索引)更好。我们来写一下中间的张量:
# expand batch-wise to (512, 1600, 84)
_som = som.view(1,-1,z.size(-1)).expand(len(z),-1,-1)
# expand z on dim=1 to match som
_z = z[:,None].expand(-1,40*40,-1)
# L2((512, 1600, 84), (512, 1, 84)) = (512, 1600, 1)
dist_l2 = torch.cdist(_som, z[:,None])[:,:,0]
# indices, shape reduced to (512,)
arg = dist_l2.argmin(1)
这样你就可以通过 _som[range(len(arg)),arg]
来获取所有的 som[row[i], col[i]]
。在向量化的形式中,你通过引入一个额外的维度(大小为 40*40
)来计算所有循环的迭代。此外,我们在 dim=1
上扩展了一个单例:
som_arg = _som[range(len(arg)),arg][:,None]
所以 som[r, c]
和 som[row[i], col[i]]
之间的 L2 距离对应于 torch.cdist(_som, som_arg)
:
# shape (512, 1600, 84) x (512, 1, 84) -> (512, 1600, 1)
l2_dist = torch.cdist(_som, som_arg)
# shape (512, 1600, 1)
neigh_dist = torch.exp(-l2_dist) / (2.0 * torch.pow(neighb_rad, 2))
# shape (512, 1600, 84) accumulated batch-wise -> (1600, 84)
out = (lr * neigh_dist * (_z - _som)).sum(0)
最后,你可以将其重塑为所需的平方形状: out.view(40,40,-1)
。