Python中使用除法与征服的矩阵乘法

2024-04-25 17:23:16 发布

您现在位置:Python中文网/ 问答频道 /正文

我用python3编写了一个程序来找出2n*n矩阵的乘积(其中n是2的幂次)。在

为什么下面的代码不起作用并显示IndexError: invalid index to scalar variable?在

import numpy as np

def product(x, y, k):
    def fsum(p, q, m):
        r = [[p[i, j] + q[i, j] for j in range(m)] for i in range(m)]
        return r

    if k == 1:
        return x[0][0] * y[0][0]
    else:
        A = x[0:(k // 2), 0:(k // 2)]
        B = x[0:(k // 2), (k // 2):k]
        C = x[(k // 2):k, 0:(k // 2)]
        D = x[(k // 2):k, (k // 2):k]

        E = y[0:(k // 2), 0:(k // 2)]
        F = y[0:(k // 2), (k // 2):k]
        G = y[(k // 2):k, 0:(k // 2)]
        H = y[(k // 2):k, (k // 2):k]

        C00 = fsum(product(A, E, k // 2), product(B, G, k // 2), k // 2)
        C01 = fsum(product(A, F, k // 2), product(B, H, k // 2), k // 2)
        C10 = fsum(product(C, E, k // 2), product(D, G, k // 2), k // 2)
        C11 = fsum(product(C, F, k // 2), product(D, H, k // 2), k // 2)

        return np.array([[C00, C01], [C10, C11]])

n = int(input('Enter index(power of 2): '))
print('Input 1st matrix')
a = np.array([[int(_) for _ in input().split()] for x in range(n)])
print('Input 2nd matrix')
b = np.array([[int(_) for _ in input().split()] for x in range(n)])
print(product(a, b, n))

Tags: inforinputindexreturndefnprange
2条回答

运行示例:

Enter index(power of 2): 2
Input 1st matrix
1 2
3 4
[[1 2]
 [3 4]]
Input 2nd matrix
3 4
5 6
[[3 4]
 [5 6]]
Traceback (most recent call last):
  File "stack51061196.py", line 35, in <module>
    print(product(a, b, n))
  File "stack51061196.py", line 21, in product
    C00 = fsum(product(A, E, k // 2), product(B, G, k // 2), k // 2)
  File "stack51061196.py", line 5, in fsum
    r = [[p[i, j] + q[i, j] for j in range(m)] for i in range(m)]
  File "stack51061196.py", line 5, in <listcomp>
    r = [[p[i, j] + q[i, j] for j in range(m)] for i in range(m)]
  File "stack51061196.py", line 5, in <listcomp>
    r = [[p[i, j] + q[i, j] for j in range(m)] for i in range(m)]
IndexError: invalid index to scalar variable.

print('p,q', p, q, type(p), type(q))添加到fsum

^{pr2}$

所以p是一个np.int64对象,而不是数组。它已经被索引了,不能再进一步了。在

In [193]: x = np.array([1])[0]
In [194]: x
Out[194]: 1
In [195]: type(x)
Out[195]: numpy.int64
In [196]: x[0,0]
IndexError: invalid index to scalar variable.

另一个诊断指纹

print('product AE',k,type(product(A, E, k // 2)))

显示器

product AE 2 <class 'numpy.int64'>

所以当k为2时,product返回一个scalar variable,一个int64对象,而不是数组。将其传递给fsum会产生错误。在

正是product的这个分支导致了问题-x[0][0](为什么不{}?)是x的元素:

if k == 1:
    return x[0][0] * y[0][0]

此时x是(1,1)形状,所以您可以只写

if k == 1:
     return x * y

有了这个变化,我得到了:

0840:~/mypy$ python3 stack51061196.py 
a,b [[1 2]
 [3 4]] [[1 2]
 [3 4]]
A [[1]] <class 'numpy.ndarray'>
product AE 2 <class 'numpy.ndarray'>
p,q [[1]] [[6]] <class 'numpy.ndarray'> <class 'numpy.ndarray'>
p,q [[2]] [[8]] <class 'numpy.ndarray'> <class 'numpy.ndarray'>
p,q [[3]] [[12]] <class 'numpy.ndarray'> <class 'numpy.ndarray'>
p,q [[6]] [[16]] <class 'numpy.ndarray'> <class 'numpy.ndarray'>
[[[[ 7]]

  [[10]]]


 [[[15]]

  [[22]]]]

除了尺寸匹配:

In [197]: a=np.array([[1,2],[3,4]])
In [198]: np.dot(a,a)
Out[198]: 
array([[ 7, 10],
       [15, 22]])

您的形状是(2, 2, 1, 1),可以用squeeze删除它,但是您确实应该优化迭代,以便在没有它的情况下获得正确的形状。在

{4x4>数组现在也得到了形状。在

以下是一种方法:

def product(X, Y):
    k = len(X)
    if k == 1:
        return X * Y
    (A, B), (C, D) = skimage.util.view_as_blocks(X, block_shape=(k // 2, k // 2))
    (E, F), (G, H) = skimage.util.view_as_blocks(Y, block_shape=(k // 2, k // 2))

    out = np.zeros((k, k))
    (I, J), (K, L) = skimage.util.view_as_blocks(out, block_shape=(k // 2, k // 2))
    I[:] = product(A, E) + product(B, G)
    J[:] = product(A, F) + product(B, H)
    K[:] = product(C, E) + product(D, G)
    L[:] = product(C, F) + product(D, H)
    return out

不用说它非常慢

^{pr2}$

相关问题 更多 >