从块对角PyTorch张量中提取块

2024-04-25 14:07:23 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个形状的张量(m*nm*n),我想提取一个大小的张量(nm*n),其中包含对角线上大小n*n的m个块。例如:

>>> a
tensor([[1, 2, 0, 0],
        [3, 4, 0, 0],
        [0, 0, 5, 6],
        [0, 0, 7, 8]])

我想要一个函数extract(a, m, n),它将输出:

>>> extract(a, 2, 2)
tensor([[1, 2, 5, 6],
        [3, 4, 7, 8]])

我考虑过使用某种切片,因为块可以表示为:

>>> for i in range(m):
...     print(a[i*m: i*m + n, i*m: i*m + n])
tensor([[1, 2],
        [3, 4]])
tensor([[5, 6],
        [7, 8]])

Tags: 函数inforextract切片range形状tensor
2条回答

您可以利用reshape和切片:

import torch
import numpy as np

def extract(a, m, n):
  s=(range(m), np.s_[:], range(m), np.s_[:])  # the slices of the blocks
  a.reshape(m, n, m, n)[s]  # reshaping according to blocks and slicing
  return a.reshape(m*n, n).T  # reshape to desired output format

例如:

a = torch.arange(36).reshape(6,6)
a
tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34, 35]])

extract(a, 3, 2)

tensor([[ 0,  6, 14, 20, 28, 34],
        [ 1,  7, 15, 21, 29, 35]])

extract(a, 2, 3)

tensor([[ 0,  6, 12, 21, 27, 33],
        [ 1,  7, 13, 22, 28, 34],
        [ 2,  8, 14, 23, 29, 35]])

对于块对角矩阵(宽度为n的大小相等的方形块),可以使用torch.nonzero()实现这一点:

>>> n = 2
>>> a[a.nonzero(as_tuple=True)].view(n,n,-1)
tensor([[[1, 2],
         [3, 4]],

        [[5, 6],
         [7, 8]]])

相关问题 更多 >