Numpy矩阵幂/指数与模运算?

20 投票
5 回答
10464 浏览
提问于 2025-04-17 08:24

有没有办法在使用numpy的linalg.matrix_power时加上一个取模,这样元素就不会变得比某个值大呢?

5 个回答

0

明显的方法有什么问题呢?

比如说:

import numpy as np

x = np.arange(100).reshape(10,10)
y = np.linalg.matrix_power(x, 2) % 50
6

这是使用Numpy库的一个实现:

https://github.com/numpy/numpy/blob/master/numpy/matrixlib/defmatrix.py#L98

我在这个基础上加了一个取模的部分。不过,有一个问题,就是如果发生了溢出,程序不会抛出OverflowError或者其他任何异常。这样一来,后面的结果就会出错。关于这个问题的报告可以在这里找到。

下面是代码,使用时要小心:

from numpy.core.numeric import concatenate, isscalar, binary_repr, identity, asanyarray, dot
from numpy.core.numerictypes import issubdtype    
def matrix_power(M, n, mod_val):
    # Implementation shadows numpy's matrix_power, but with modulo included
    M = asanyarray(M)
    if len(M.shape) != 2 or M.shape[0] != M.shape[1]:
        raise ValueError("input  must be a square array")
    if not issubdtype(type(n), int):
        raise TypeError("exponent must be an integer")

    from numpy.linalg import inv

    if n==0:
        M = M.copy()
        M[:] = identity(M.shape[0])
        return M
    elif n<0:
        M = inv(M)
        n *= -1

    result = M % mod_val
    if n <= 3:
        for _ in range(n-1):
            result = dot(result, M) % mod_val
        return result

    # binary decompositon to reduce the number of matrix
    # multiplications for n > 3
    beta = binary_repr(n)
    Z, q, t = M, 0, len(beta)
    while beta[t-q-1] == '0':
        Z = dot(Z, Z) % mod_val
        q += 1
    result = Z
    for k in range(q+1, t):
        Z = dot(Z, Z) % mod_val
        if beta[t-k-1] == '1':
            result = dot(result, Z) % mod_val
    return result % mod_val
11

为了防止溢出,你可以利用一个事实:如果你先对每个输入数字取模,结果是一样的。实际上:

(M**k) mod p = ([M mod p]**k) mod p,

对于一个矩阵 M。这个结论来源于以下两个基本公式,这些公式适用于整数 xy(以及一个正的幂 p):

(x+y) mod p = ([x mod p]+[y mod p]) mod p  # All additions can be done on numbers *modulo p*
(x*y) mod p = ([x mod p]*[y mod p]) mod p  # All multiplications can be done on numbers *modulo p*

这些公式同样适用于矩阵,因为矩阵的加法和乘法可以通过标量(普通数字)的加法和乘法来表示。这样,你只需要对小数字进行幂运算(n mod p 通常比 n 小得多),就不太可能出现溢出的问题。在 NumPy 中,你可以简单地这样做:

((arr % p)**k) % p

以获得 (arr**k) mod p

如果这样做仍然不够(也就是说,如果 [n mod p]**k 仍然有可能导致溢出,尽管 n mod p 很小),你可以把幂运算拆分成多个小的幂运算。上面的基本公式可以得出:

(n**[a+b]) mod p = ([{n mod p}**a mod p] * [{n mod p}**b mod p]) mod p

以及

(n**[a*b]) mod p = ([n mod p]**a mod p)**b mod p.

因此,你可以把幂 k 拆分成 a+b+…a*b*… 或它们的任何组合。上述公式允许你只对小数字进行小数字的幂运算,这大大降低了整数溢出的风险。

撰写回答