判断矩阵中每个元素是否为其转置在三维张量中的倒数
我有一个形状为 m * n * n 的 numpy 矩阵,这意味着我有 m 个 n*n 的方阵。对于每个矩阵,我需要确保元素的排列方式是这样的:转置后的索引是彼此的倒数。如果发现任何不符合这个规则的情况,就返回 False;如果没有问题,就返回 True。
在图片中,答案是对的,因为我们有四个形状为 3*3 的矩阵,并且每个矩阵中的元素都是转置元素的倒数。
这段代码是要投入使用的,所以我希望这个过程能尽量快,并且能最大限度地实现并行处理。
我不知道该怎么开始。我试过 ChatGPT,但它给我的结果很糟糕。
1 个回答
0
首先,请注意,在你的例子中,你的条件可能没有达到你想要的效果,因为你想检查转置索引上的条目是否完全互为倒数。但是 0.11111111 * 9 != 1
。我假设0.11111111是1/9,但1/9在二进制中无法精确存储为浮点数。
无论如何,要检查一个n x n的数组是否满足你的要求,你可以将这个数组和它的转置矩阵逐元素相乘,然后检查这个矩阵中的所有元素是否都是1:
import numpy as np
a = np.array([[1,0.25,4],[4,1,9], [0.25, 1/9, 1]])
b = np.multiply(a, a.T)
# check if any values in b are not 1
print(np.any(b != 1.0)
如果你在检查精确相等时遇到问题(如上所述),你可以考虑使用np.isclose()。
对于你的m x n x n数组,你可以简单地遍历所有的m个数组。可能还有更高效的方法,但据我所知,没有直接的方法可以对3D数组的2D切片应用函数:
def check_reciprocal(mats: np.ndarray) -> bool:
for mat in mats:
if np.any(np.multiply(mat, mat.T) != 1.0):
return False
return True
为了加快速度,可以考虑将2D矩阵重塑为1D数组,然后使用np.apply_along_axis(),或者使用numba编译这个函数。