Numpy矩阵幂/指数与模运算?
有没有办法在使用numpy的linalg.matrix_power时加上一个取模,这样元素就不会变得比某个值大呢?
5 个回答
0
明显的方法有什么问题呢?
比如说:
import numpy as np
x = np.arange(100).reshape(10,10)
y = np.linalg.matrix_power(x, 2) % 50
6
这是使用Numpy库的一个实现:
https://github.com/numpy/numpy/blob/master/numpy/matrixlib/defmatrix.py#L98
我在这个基础上加了一个取模的部分。不过,有一个问题,就是如果发生了溢出,程序不会抛出OverflowError
或者其他任何异常。这样一来,后面的结果就会出错。关于这个问题的报告可以在这里找到。
下面是代码,使用时要小心:
from numpy.core.numeric import concatenate, isscalar, binary_repr, identity, asanyarray, dot
from numpy.core.numerictypes import issubdtype
def matrix_power(M, n, mod_val):
# Implementation shadows numpy's matrix_power, but with modulo included
M = asanyarray(M)
if len(M.shape) != 2 or M.shape[0] != M.shape[1]:
raise ValueError("input must be a square array")
if not issubdtype(type(n), int):
raise TypeError("exponent must be an integer")
from numpy.linalg import inv
if n==0:
M = M.copy()
M[:] = identity(M.shape[0])
return M
elif n<0:
M = inv(M)
n *= -1
result = M % mod_val
if n <= 3:
for _ in range(n-1):
result = dot(result, M) % mod_val
return result
# binary decompositon to reduce the number of matrix
# multiplications for n > 3
beta = binary_repr(n)
Z, q, t = M, 0, len(beta)
while beta[t-q-1] == '0':
Z = dot(Z, Z) % mod_val
q += 1
result = Z
for k in range(q+1, t):
Z = dot(Z, Z) % mod_val
if beta[t-k-1] == '1':
result = dot(result, Z) % mod_val
return result % mod_val
11
为了防止溢出,你可以利用一个事实:如果你先对每个输入数字取模,结果是一样的。实际上:
(M**k) mod p = ([M mod p]**k) mod p,
对于一个矩阵 M
。这个结论来源于以下两个基本公式,这些公式适用于整数 x
和 y
(以及一个正的幂 p):
(x+y) mod p = ([x mod p]+[y mod p]) mod p # All additions can be done on numbers *modulo p*
(x*y) mod p = ([x mod p]*[y mod p]) mod p # All multiplications can be done on numbers *modulo p*
这些公式同样适用于矩阵,因为矩阵的加法和乘法可以通过标量(普通数字)的加法和乘法来表示。这样,你只需要对小数字进行幂运算(n mod p 通常比 n 小得多),就不太可能出现溢出的问题。在 NumPy 中,你可以简单地这样做:
((arr % p)**k) % p
以获得 (arr**k) mod p
。
如果这样做仍然不够(也就是说,如果 [n mod p]**k
仍然有可能导致溢出,尽管 n mod p
很小),你可以把幂运算拆分成多个小的幂运算。上面的基本公式可以得出:
(n**[a+b]) mod p = ([{n mod p}**a mod p] * [{n mod p}**b mod p]) mod p
以及
(n**[a*b]) mod p = ([n mod p]**a mod p)**b mod p.
因此,你可以把幂 k
拆分成 a+b+…
或 a*b*…
或它们的任何组合。上述公式允许你只对小数字进行小数字的幂运算,这大大降低了整数溢出的风险。