有 Java 编程相关的问题?

你可以在下面搜索框中键入要查询的问题!

math Java 8流矩阵乘法比For循环慢10倍?

我创建了一个使用流执行矩阵乘法的模块。可以在这里找到: https://github.com/firefly-math/firefly-math-linear-real/

我试图编写一个基准测试,以便将流循环实现与Apache Commons Math中相应的for循环实现进行比较

基准模块如下所示: https://github.com/firefly-math/firefly-math-benchmark

这里的实际基准是: https://github.com/firefly-math/firefly-math-benchmark/blob/master/src/main/java/com/fireflysemantics/benchmark/MultiplyBenchmark.java

当我在大小为100X100和1000X1000的矩阵上运行基准测试时,结果表明Apache Commons Math(使用for循环)比相应的流实现快10倍(大致)

# Run complete. Total time: 00:14:10

Benchmark                              Mode  Cnt      Score     Error      Units
MultiplyBenchmark.multiplyCM1000_1000  avgt   30   1040.804 ±  11.796  ms/op
MultiplyBenchmark.multiplyCM100_100    avgt   30      0.790 ±   0.010  ms/op
MultiplyBenchmark.multiplyFM1000_1000  avgt   30  11981.228 ± 405.812  ms/op
MultiplyBenchmark.multiplyFM100_100    avgt   30      7.224 ±   0.685  ms/op

我在基准测试中是否做错了什么(希望:)

我添加了经过测试的方法,这样每个人都可以看到正在比较的内容。这是Apache Commons Math Array2DroArralMatrix。multiply()方法:

/**
 * Returns the result of postmultiplying {@code this} by {@code m}.
 *
 * @param m matrix to postmultiply by
 * @return {@code this * m}
 * @throws DimensionMismatchException if
 * {@code columnDimension(this) != rowDimension(m)}
 */
public Array2DRowRealMatrix multiply(final Array2DRowRealMatrix m)
    throws DimensionMismatchException {
    MatrixUtils.checkMultiplicationCompatible(this, m);

    final int nRows = this.getRowDimension();
    final int nCols = m.getColumnDimension();
    final int nSum = this.getColumnDimension();

    final double[][] outData = new double[nRows][nCols];
    // Will hold a column of "m".
    final double[] mCol = new double[nSum];
    final double[][] mData = m.data;

    // Multiply.
    for (int col = 0; col < nCols; col++) {
        // Copy all elements of column "col" of "m" so that
        // will be in contiguous memory.
        for (int mRow = 0; mRow < nSum; mRow++) {
            mCol[mRow] = mData[mRow][col];
        }

        for (int row = 0; row < nRows; row++) {
            final double[] dataRow = data[row];
            double sum = 0;
            for (int i = 0; i < nSum; i++) {
                sum += dataRow[i] * mCol[i];
            }
            outData[row][col] = sum;
        }
    }

    return new Array2DRowRealMatrix(outData, false);
}

这是相应的流实现:

/**
 * Returns a {@link BinaryOperator} that multiplies {@link SimpleMatrix}
 * {@code m1} times {@link SimpleMatrix} {@code m2} (m1 X m2).
 * 
 * Example {@code multiply(true).apply(m1, m2);}
 * 
 * @param parallel
 *            Whether to perform the operation concurrently.
 * 
 * @throws MathException
 *             Of type {@code MATRIX_DIMENSION_MISMATCH__MULTIPLICATION} if
 *             {@code m} is not the same size as {@code this}.
 * 
 * @return the {@link BinaryOperator} that performs the operation.
 */
public static BinaryOperator<SimpleMatrix> multiply(boolean parallel) {

    return (m1, m2) -> {
        checkMultiplicationCompatible(m1, m2);

        double[][] a1 = m1.toArray();
        double[][] a2 = m2.toArray();

        Stream<double[]> stream = Arrays.stream(a1);
        stream = parallel ? stream.parallel() : stream;

        final double[][] result =
                stream.map(r -> range(0, a2[0].length)
                        .mapToDouble(i -> range(0, a2.length).mapToDouble(j -> r[j]
                                * a2[j][i]).sum())
                        .toArray()).toArray(double[][]::new);

        return new SimpleMatrix(result);
    };
}

蒂亚, 奥立


共 (1) 个答案

  1. # 1 楼答案

    看看DoublePipeline.toArray

    public final double[] toArray() {
      return Nodes.flattenDouble((Node.OfDouble) evaluateToArrayNode(Double[]::new))
                        .asPrimitiveArray();
    }
    

    似乎先创建一个装箱数组,然后将其转换为一个基本数组