Python 矩阵乘法
我正在尝试创建一个Python程序,用来实现Strassen算法和普通的矩阵乘法方法。但是,当我用createRandom Matrix函数生成的随机矩阵来运行我的Strassen函数时,出现了这个错误:
Traceback (most recent call last):
File "matrixMult.py", line 106, in <module>
print strassen(c, d, 10)
File "matrixMult.py", line 77, in strassen
p1 = strassen(addMatrix(a11,a22), addMatrix(b11,b22), n/2)
File "matrixMult.py", line 78, in strassen
p2 = strassen(addMatrix(a21,a22), b11, n/2)
File "matrixMult.py", line 82, in strassen
p6 = strassen(subMatrix(a21,a11), addMatrix(b11,b12), n/2)
File "matrixMult.py", line 62, in subMatrix
c.append(a[i][j] - b[i][j])
IndexError: list index out of range
这是我的代码。我随机创建了一个10x10的矩阵,然后尝试用它进行Strassen算法,但遇到了之前提到的错误。不过,当我使用我在代码最后定义的简单4x4矩阵时,Strassen算法运行得很好,看起来我的随机矩阵生成没有问题,所以我不太确定问题出在哪里。有没有人有好的建议?
import random
import time
random.seed()
def createEmptyMatrix(x, y): # create empty matrix
matrix = [[0 for row in range(x)] for col in range(y)]
return matrix
def createRandomMatrix(size): # create matrix filled with random ints
matrix = []
matrix = [[random.randint(1,20) for row in range(size)] for col in range(10)]
return matrix
def regular(a, b): # standard O(n^3) matrix multiplication
c = createEmptyMatrix(len(a), len(b[0]))
for i in range(len(a)):
for j in range(len(b[0])):
for k in range(len(b)):
c[i][j] += a[i][k]*b[k][j]
return c
def split(matrix): # split matrix into quarters for strassen
a = matrix
b = matrix
c = matrix
d = matrix
while(len(a) > len(matrix)/2):
a = a[:len(a)/2]
b = b[:len(b)/2]
c = c[len(c)/2:]
d = d[len(d)/2:]
while(len(a[0]) > len(matrix[0])/2):
for i in range(len(a[0])/2):
a[i] = a[i][:len(a[i])/2]
b[i] = b[i][len(b[i])/2:]
c[i] = c[i][:len(c[i])/2]
d[i] = d[i][len(d[i])/2:]
return a,b,c,d
def addMatrix(a, b): # add 2 matrices
d = []
for i in range(len(a)):
c = []
for j in range(len(a[0])):
c.append(a[i][j] + b[i][j])
d.append(c)
return d
def subMatrix(a, b): # subtract 2 matrices
d = []
for i in range(len(a)):
c = []
for j in range(len(a[0])):
c.append(a[i][j] - b[i][j])
d.append(c)
return d
def strassen(a, b, n): # strassen matrix multiplication method
#base case
if n == 1:
d = [[0]]
d[0][0] = a[0][0] * b[0][0]
return d
else:
a11, a12, a21, a22 = split(a)
b11, b12, b21, b22 = split(b)
p1 = strassen(addMatrix(a11,a22), addMatrix(b11,b22), n/2)
p2 = strassen(addMatrix(a21,a22), b11, n/2)
p3 = strassen(a11, subMatrix(b12,b22), n/2)
p4 = strassen(a22, subMatrix(b21,b11), n/2)
p5 = strassen(addMatrix(a11,a12), b22, n/2)
p6 = strassen(subMatrix(a21,a11), addMatrix(b11,b12), n/2)
p7 = strassen(subMatrix(a12,a22), addMatrix(b21,b22), n/2)
c11 = addMatrix(subMatrix(addMatrix(p1, p4), p5), p7)
c12 = addMatrix(p3, p5)
c21 = addMatrix(p2, p4)
c22 = addMatrix(subMatrix(addMatrix(p1, p3), p2), p6)
c = createEmptyMatrix(len(c11)*2,len(c11)*2)
for i in range(len(c11)):
for j in range(len(c11)):
c[i][j] = c11[i][j]
c[i][j+len(c11)] = c12[i][j]
c[i+len(c11)][j] = c21[i][j]
c[i+len(c11)][j+len(c11)] = c22[i][j]
return c
a = [[1,1,1,1],[2,2,2,2],[3,3,3,3],[4,4,4,4]]
b = [[5,5,5,5],[6,6,6,6],[7,7,7,7],[8,8,8,8]]
c = createRandomMatrix(10)
d = createRandomMatrix(10)
print "Strassen Outputs:"
#print strassen(c, d, 10)
print "Should be:"
print regular(c, d)
print c
print d
print a
print b
print strassen(a, b, 4)
2 个回答
0
最后一行的回溯信息告诉你哪里出错了:
File "matrixMult.py", line 62, in subMatrix
c.append(a[i][j] - b[i][j])
IndexError: list index out of range
这一行用了4个数组索引,其中有一个超出了数组的范围。
要调试这个问题,去第62行,在它之前加上一个print i,j
。这样你会得到很多输出,而在异常发生之前的输出行会告诉你哪个索引超出了范围。通过这种方式,你可能能找到你这里的错误。
“就这样调试吧”
1
我建议你使用 numpy
,这个库让你可以很方便地使用矩阵,而且里面已经有很多现成的功能可以用。
同时,如果你在这个函数中遇到索引错误,可以试着加一个 assert
语句来帮助检查:
def subMatrix(a, b): # subtract 2 matrices
assert len(a) == len(b), "Number of rows does not match!"
assert len(a[0]) == len(b[0]), "Number of columns does not match!"
d = []
for i in range(len(a)):
c = []
for j in range(len(a[0])):
c.append(a[i][j] - b[i][j])
d.append(c)
return d
不过其实你根本不需要自己写这个函数:
import numpy as np
a = np.matrix(np.random.randint(10, size=(3,3)))
b = np.matrix(np.random.randint(10, size=(3,))).T
c = a * b
d = a - b
print a
[[5 8 1]
[7 6 1]
[9 2 9]]
print b
[[5]
[2]
[4]]
print c
[[45]
[51]
[85]]
print d
[[ 0 3 -4]
[ 5 4 -1]
[ 5 -2 5]]