Clojure与Numpy的矩阵乘法比较
我正在用Clojure开发一个应用程序,需要对大矩阵进行乘法运算,但遇到了一些性能问题。相比之下,使用Numpy的版本表现得要快得多。Numpy可以在不到一秒的时间内将一个1,000,000x23的矩阵与它的转置相乘,而相同的Clojure代码却要花超过六分钟。(我可以打印出Numpy的结果,所以它肯定是计算了所有内容)。
我在这个Clojure代码中是不是做错了什么?有没有什么Numpy的技巧可以让我尝试模仿一下?
这是Python代码:
import numpy as np
def test_my_mult(n):
A = np.random.rand(n*23).reshape(n,23)
At = A.T
t0 = time.time()
res = np.dot(A.T, A)
print time.time() - t0
print np.shape(res)
return res
# Example (returns a 23x23 matrix):
# >>> results = test_my_mult(1000000)
#
# 0.906938076019
# (23, 23)
这是Clojure代码:
(defn feature-vec [n]
(map (partial cons 1)
(for [x (range n)]
(take 22 (repeatedly rand)))))
(defn dot-product [x y]
(reduce + (map * x y)))
(defn transpose
"returns the transposition of a `coll` of vectors"
[coll]
(apply map vector coll))
(defn matrix-mult
[mat1 mat2]
(let [row-mult (fn [mat row]
(map (partial dot-product row)
(transpose mat)))]
(map (partial row-mult mat2)
mat1)))
(defn test-my-mult
[n afn]
(let [xs (feature-vec n)
xst (transpose xs)]
(time (dorun (afn xst xs)))))
;; Example (yields a 23x23 matrix):
;; (test-my-mult 1000 i/mmult) => "Elapsed time: 32.626 msecs"
;; (test-my-mult 10000 i/mmult) => "Elapsed time: 628.841 msecs"
;; (test-my-mult 1000 matrix-mult) => "Elapsed time: 14.748 msecs"
;; (test-my-mult 10000 matrix-mult) => "Elapsed time: 434.128 msecs"
;; (test-my-mult 1000000 matrix-mult) => "Elapsed time: 375751.999 msecs"
;; Test from wikipedia
;; (def A [[14 9 3] [2 11 15] [0 12 17] [5 2 3]])
;; (def B [[12 25] [9 10] [8 5]])
;; user> (matrix-mult A B)
;; ((273 455) (243 235) (244 205) (102 160))
更新:我使用JBLAS库实现了相同的基准测试,发现速度有了巨大的提升。感谢大家的建议!现在可以把这个代码用Clojure封装起来了。以下是新的代码:
(import '[org.jblas FloatMatrix])
(defn feature-vec [n]
(FloatMatrix.
(into-array (for [x (range n)]
(float-array (cons 1 (take 22 (repeatedly rand))))))))
(defn test-mult [n]
(let [xs (feature-vec n)
xst (.transpose xs)]
(time (let [result (.mmul xst xs)]
[(.rows result)
(.columns result)]))))
;; user> (test-mult 10000)
;; "Elapsed time: 6.99 msecs"
;; [23 23]
;; user> (test-mult 100000)
;; "Elapsed time: 43.88 msecs"
;; [23 23]
;; user> (test-mult 1000000)
;; "Elapsed time: 383.439 msecs"
;; [23 23]
(defn matrix-stream [rows cols]
(repeatedly #(FloatMatrix/randn rows cols)))
(defn square-benchmark
"Times the multiplication of a square matrix."
[n]
(let [[a b c] (matrix-stream n n)]
(time (.mmuli a b c))
nil))
;; forma.matrix.jblas> (square-benchmark 10)
;; "Elapsed time: 0.113 msecs"
;; nil
;; forma.matrix.jblas> (square-benchmark 100)
;; "Elapsed time: 0.548 msecs"
;; nil
;; forma.matrix.jblas> (square-benchmark 1000)
;; "Elapsed time: 107.555 msecs"
;; nil
;; forma.matrix.jblas> (square-benchmark 2000)
;; "Elapsed time: 793.022 msecs"
;; nil
9 个回答
我刚刚对比了一下 Incanter 1.3 和 jBLAS 1.2.1 的性能。下面是我用的代码:
(ns ml-class.experiments.mmult
[:use [incanter core]]
[:import [org.jblas DoubleMatrix]])
(defn -main [m]
(let [n 23 m (Integer/parseInt m)
ai (matrix (vec (double-array (* m n) (repeatedly rand))) n)
ab (DoubleMatrix/rand m n)
ti (copy (trans ai))
tb (.transpose ab)]
(dotimes [i 20]
(print "Incanter: ") (time (mmult ti ai))
(print " jBLAS: ") (time (.mmul tb ab)))))
在我的测试中,Incanter 在普通的矩阵乘法上比 jBLAS 慢了大约 45%。不过,Incanter 的 trans
函数不会创建矩阵的新副本,因此在 jBLAS 中使用 (.mmul (.transpose ab) ab)
会占用两倍的内存,而它的速度只比 Incanter 的 (mmult (trans ai) ai)
快 15%。
考虑到 Incanter 拥有丰富的功能(特别是它的绘图库),我觉得我不会很快切换到 jBLAS。不过,我还是很想看到 jBLAS 和 Parallel Colt 之间的另一次对比,也许可以考虑用 jBLAS 替代 Incanter 中的 Parallel Colt?:-)
编辑:以下是我在我的(相对较慢的)电脑上得到的绝对数字(单位:毫秒):
Incanter: 665.362452
jBLAS: 459.311598
numpy: 353.777885
对于每个库,我从20次运行中选出了最佳时间,矩阵大小为 23x400000。
PS. Haskell 的 hmatrix 结果接近 numpy,但我不确定如何正确地进行基准测试。
Numpy 是一个链接到 BLAS/Lapack 库的工具,这些库经过几十年的优化,能够在机器架构层面上高效运行。而 Clojure 则是以最简单、最直接的方式来实现矩阵乘法。
每当你需要进行复杂的矩阵或向量运算时,最好是使用 BLAS/LAPACK 这些库。
唯一的例外是当你处理的是小矩阵时,如果从某些编程语言中转换数据格式到 LAPACK 的开销比实际计算的时间还要长,那就不一定会更快了。
Python的版本在编译时会变成C语言中的一个循环,而Clojure的版本则是在每次调用map时都生成一个新的中间序列。你看到的性能差异很可能是因为这两种语言使用的数据结构不同。
如果想要更好的性能,可以尝试使用像Incanter这样的库,或者按照这个问题的说明自己写一个版本。还可以看看这个问题,以及neanderthal或nd4j。如果你真的想保持使用序列,以便保留惰性求值的特性等,那么可以研究一下transients,这对内部矩阵计算可能会有很大帮助。
编辑:忘了提到调优Clojure的第一步,就是开启“在反射时警告”。