我有一个形状的张量(m*n,m*n),我想提取一个大小的张量(n,m*n),其中包含对角线上大小n*n的m个块。例如:
>>> a
tensor([[1, 2, 0, 0],
[3, 4, 0, 0],
[0, 0, 5, 6],
[0, 0, 7, 8]])
我想要一个函数extract(a, m, n)
,它将输出:
>>> extract(a, 2, 2)
tensor([[1, 2, 5, 6],
[3, 4, 7, 8]])
我考虑过使用某种切片,因为块可以表示为:
>>> for i in range(m):
... print(a[i*m: i*m + n, i*m: i*m + n])
tensor([[1, 2],
[3, 4]])
tensor([[5, 6],
[7, 8]])
您可以利用
reshape
和切片:例如:
对于块对角矩阵(宽度为
n
的大小相等的方形块),可以使用torch.nonzero()
实现这一点:相关问题 更多 >
编程相关推荐