判断矩阵中每个元素是否为其转置在三维张量中的倒数

0 投票
1 回答
57 浏览
提问于 2025-04-14 17:55

我有一个形状为 m * n * n 的 numpy 矩阵,这意味着我有 m 个 n*n 的方阵。对于每个矩阵,我需要确保元素的排列方式是这样的:转置后的索引是彼此的倒数。如果发现任何不符合这个规则的情况,就返回 False;如果没有问题,就返回 True。

在图片中,答案是对的,因为我们有四个形状为 3*3 的矩阵,并且每个矩阵中的元素都是转置元素的倒数。

这段代码是要投入使用的,所以我希望这个过程能尽量快,并且能最大限度地实现并行处理。

这张图片

我不知道该怎么开始。我试过 ChatGPT,但它给我的结果很糟糕。

1 个回答

0

首先,请注意,在你的例子中,你的条件可能没有达到你想要的效果,因为你想检查转置索引上的条目是否完全互为倒数。但是 0.11111111 * 9 != 1。我假设0.11111111是1/9,但1/9在二进制中无法精确存储为浮点数

无论如何,要检查一个n x n的数组是否满足你的要求,你可以将这个数组和它的转置矩阵逐元素相乘,然后检查这个矩阵中的所有元素是否都是1:

import numpy as np

a = np.array([[1,0.25,4],[4,1,9], [0.25, 1/9, 1]])

b = np.multiply(a, a.T)

# check if any values in b are not 1
print(np.any(b != 1.0) 

如果你在检查精确相等时遇到问题(如上所述),你可以考虑使用np.isclose()

对于你的m x n x n数组,你可以简单地遍历所有的m个数组。可能还有更高效的方法,但据我所知,没有直接的方法可以对3D数组的2D切片应用函数:

def check_reciprocal(mats: np.ndarray) -> bool:
   for mat in mats:
      if np.any(np.multiply(mat, mat.T) != 1.0):
         return False
   return True

为了加快速度,可以考虑将2D矩阵重塑为1D数组,然后使用np.apply_along_axis(),或者使用numba编译这个函数。

撰写回答