有 Java 编程相关的问题?

你可以在下面搜索框中键入要查询的问题!

Java中两个矩阵相乘的数组

我目前正在开发一个类来表示矩阵,它表示任何通用的mxn矩阵。我已经计算出加法和标量乘法,但我正在努力发展两个矩阵的乘法。矩阵的数据保存在二维双精度数组中

该方法看起来有点像这样:

   public Matrix multiply(Matrix A) {
            ////code
   }

它将返回产品矩阵。这是右边的乘法运算。所以,如果我调用A.multiply(B),它将返回矩阵AB,B在右边

我还不需要担心检查乘法是否定义在给定的矩阵上,我可以假设我将得到正确维数的矩阵

有没有人知道一种简单的算法,甚至可能是伪代码来执行乘法过程

提前谢谢


共 (4) 个答案

  1. # 1 楼答案

    对多个任意维数组尝试此代码并打印它。我认为这更简单,任何人都能理解

    public class Test
    {
    
      public static void main(String[] args)
      {
        int[][] array1 = {
            { 1, 4, -2 },
            { 3, 5, -6 },
            { 4, 5, 2 }
        };
    
        int[][] array2 = {
            { 5, 2, 8, -1 },
            { 3, 6, 4, 5 },
            { -2, 9, 7, -3 }
        };
        Test test = new Test();
        test.printArray(test.multiplication(array1, array2));
      }
    
      private int[][] multiplication(int[][] array1, int[][] array2)
      {
        int r1, r2, c1, c2;
        r1 = array1.length;
        c1 = array1[0].length;
    
        r2 = array2.length;
        c2 = array2[0].length;
    
        int[][] result;
        if (c1 != r2)
        {
          System.out.println("Error!");
          result = new int[0][0];
        }
        else
        {
          result = new int[r1][c2];
    
          for (int i = 0; i < r1; i++)//2
          {
            for (int j = 0; j < c2; j++)//4
            {
              for (int k = 0; k < c1; k++)
              {
                result[i][j] += array1[i][k] * array2[k][j];
              }
            }
          }
        }
    
        return result;
      }
    
      private void printArray(int[][] array)
      {
        for (int[] arr : array)
        {
          for (int element : arr)
          {
            System.out.print(element + " ");
          }
          System.out.println();
        }
      }
    }
    
  2. # 2 楼答案

    在这个答案中,我创建了一个名为Matrix的类,另一个类称为MatrixOperations,它定义了可以对矩阵执行的各种操作(当然行操作除外)。但我将从矩阵运算中提取乘法的代码。完整的项目可以在我的GitHub页面here上找到

    下面是矩阵类的定义

    package app.matrix;
    
    import app.matrix.util.MatrixException;
    
    public class Matrix {
    
    private double[][] entries;
    
    public void setEntries(double[][] entries) {
        this.entries = entries;
    }
    
    private String name;
    
    public double[][] getEntries() {
        return entries;
    }
    
    public String getName() {
        return name;
    }
    
    public void setName(String name) {
        this.name = name;
    }
    
    public class Dimension {
        private int rows;
        private int columns;
    
        public int getRows() {
            return rows;
        }
    
        public void setRows(int rows) {
            this.rows = rows;
        }
    
        public int getColumns() {
            return columns;
        }
    
        public void setColumns(int columns) {
            this.columns = columns;
        }
    
        public Dimension(int rows, int columns) {
            this.setRows(rows);
            this.setColumns(columns);
        }
    
        @Override
        public boolean equals(Object obj) {
            if(obj instanceof Dimension){
                return (this.getColumns() == ((Dimension) obj).getColumns()) && (this.getRows() == ((Dimension) obj).getRows());
            }
            return false;
        }
    }
    
    private Dimension dimension;
    
    public Dimension getDimension() {
        return dimension;
    }
    
    public void setDimension(Dimension dimension) {
        this.dimension = dimension;
    }
    
    public Matrix(int dimension, String name) throws MatrixException {
        if (dimension == 0) throw new MatrixException(ZERO_UNIT_DIMENSION);
        else this.setEntries(new double[Math.abs(dimension)][Math.abs(dimension)]);
        this.setDimension(new Dimension(dimension, dimension));
        this.setName(name);
    }
    
    public Matrix(int dimensionH, int dimensionV, String name) throws MatrixException {
        if (dimensionH == 0 || dimensionV == 0) throw new MatrixException(ZERO_UNIT_DIMENSION);
        else this.setEntries(new double[Math.abs(dimensionH)][Math.abs(dimensionV)]);
        this.setDimension(new Dimension(dimensionH, dimensionV));
        this.setName(name);
    
    }
    
    private static final String OVERFLOW_ITEMS_MSG = "The values are too many for the matrix's specified dimensions";
    private static final String ZERO_UNIT_DIMENSION = "Zero cannot be a value for a dimension";
    
    public Matrix(int dimensionH, int dimensionV, String name, double... values) throws MatrixException {
        if (dimensionH == 0 || dimensionV == 0) throw new MatrixException(ZERO_UNIT_DIMENSION);
        else if (values.length > dimensionH * dimensionV) throw new MatrixException(Matrix.OVERFLOW_ITEMS_MSG);
        else this.setEntries(new double[Math.abs(dimensionH)][Math.abs(dimensionV)]);
        this.setDimension(new Dimension(dimensionH, dimensionV));
        this.setName(name);
    
        int iterator = 0;
        int j;
        for (int i = 0; i < dimensionH; i++) {
            j = 0;
            while (j < dimensionV) {
                this.entries[i][j] = values[iterator];
                j++;
                iterator++;
            }
        }
    }
    
    public Matrix(Dimension dimension) throws MatrixException {
        this(dimension.getRows(), dimension.getColumns(), null);
    }
    
    public static Matrix identityMatrix(int dim) throws MatrixException {
        if (dim == 0) throw new MatrixException(ZERO_UNIT_DIMENSION);
    
        double[] i = new double[dim * dim];
        int constant = dim + 1;
        for (int j = 0; j < i.length; j = j + constant) {
            i[j] = 1.0;
        }
    
        return new Matrix(dim, dim, null, i);
    }
    
    public String toString() {
    
        StringBuilder builder = new StringBuilder("Matrix \"" + (this.getName() == null ? "Null Matrix" : this.getName()) + "\": {\n");
    
        for (int i = 0; i < this.getDimension().getRows(); i++) {
            for (int j = 0; j < this.getDimension().getColumns(); j++) {
                if (j == 0) builder.append("\t");
                builder.append(this.entries[i][j]);
                if (j != this.getDimension().getColumns() - 1)
                    builder.append(", ");
            }
            if (i != this.getDimension().getRows()) builder.append("\n");
        }
    
        builder.append("}");
    
        return builder.toString();
    }
    
    public boolean isSquare() {
        return this.getDimension().getColumns() == this.getDimension().getRows();
    }
    

    }

    这是从矩阵运算得到矩阵乘法的代码方法

    public static Matrix multiply(Matrix matrix1, Matrix matrix2) throws MatrixException {
    
        if (matrix1.getDimension().getColumns() != matrix2.getDimension().getRows())
            throw new MatrixException(MATRIX_MULTIPLICATION_ERROR_MSG);
    
        Matrix retVal = new Matrix(matrix1.getDimension().getRows(), matrix2.getDimension().getColumns(), matrix1.getName() + " x " + matrix2.getName());
    
    
        for (int i = 0; i < matrix1.getDimension().getRows(); i++) {
            for (int j = 0; j < matrix2.getDimension().getColumns(); j++) {
                retVal.getEntries()[i][j] = sum(arrayProduct(matrix1.getEntries()[i], getColumnMatrix(matrix2, j)));
            }
        }
    
        return retVal;
    }
    

    下面是sum、arrayProduct和getColumnMatrix方法的代码

    private static double sum(double... values) {
        double sum = 0;
        for (double value : values) {
            sum += value;
        }
        return sum;
    }
    
    private static double[] arrayProduct(double[] arr1, double[] arr2) throws MatrixException {
        if (arr1.length != arr2.length) throw new MatrixException("Array lengths must be the same");
        double[] retVal = new double[arr1.length];
        for (int i = 0; i < arr1.length; i++) {
            retVal[i] = arr1[i] * arr2[i];
        }
    
        return retVal;
    }
    
    
    private static double[] getColumnMatrix(Matrix matrix, int col) {
        double[] ret = new double[matrix.getDimension().getRows()];
        for (int i = 0; i < matrix.getDimension().getRows(); i++) {
            ret[i] = matrix.getEntries()[i][col];
        }
        return ret;
    }
    
  3. # 3 楼答案

    从数学上讲,矩阵A(l x m)和B(m x n)的乘积被定义为由以下元素组成的矩阵C(l x n):

            m
    c_i_j = ∑  a_i_k * b_k_j
           k=1
    

    因此,如果您没有太多的速度,您可能会对直接的O(n^3)实现感到满意:

      for (int i=0; i<l; ++i)
        for (int j=0; j<n; ++j)
          for (int k=0; k<m; ++k)
            c[i][j] += a[i][k] * b[k][j]  
    

    相反,如果你正在加快速度,你可能需要检查其他替代方法,比如Strassen算法(参见:Strassen算法)

    尽管如此,还是要注意——尤其是在现代处理器体系结构上对小矩阵进行乘法时,速度在很大程度上取决于矩阵数据和乘法顺序的安排,以充分利用缓存线

    我强烈怀疑使用虚拟机是否有可能影响这一因素,所以我不确定是否要考虑到这一点

  4. # 4 楼答案

    Java。矩阵乘法

    下面是“执行乘法过程的代码”。用不同大小的矩阵进行测试

    public class Matrix {
    
    /**
     * Matrix multiplication method.
     * @param m1 Multiplicand
     * @param m2 Multiplier
     * @return Product
     */
        public static double[][] multiplyByMatrix(double[][] m1, double[][] m2) {
            int m1ColLength = m1[0].length; // m1 columns length
            int m2RowLength = m2.length;    // m2 rows length
            if(m1ColLength != m2RowLength) return null; // matrix multiplication is not possible
            int mRRowLength = m1.length;    // m result rows length
            int mRColLength = m2[0].length; // m result columns length
            double[][] mResult = new double[mRRowLength][mRColLength];
            for(int i = 0; i < mRRowLength; i++) {         // rows from m1
                for(int j = 0; j < mRColLength; j++) {     // columns from m2
                    for(int k = 0; k < m1ColLength; k++) { // columns from m1
                        mResult[i][j] += m1[i][k] * m2[k][j];
                    }
                }
            }
            return mResult;
        }
    
        public static String toString(double[][] m) {
            String result = "";
            for(int i = 0; i < m.length; i++) {
                for(int j = 0; j < m[i].length; j++) {
                    result += String.format("%11.2f", m[i][j]);
                }
                result += "\n";
            }
            return result;
        }
    
        public static void main(String[] args) {
            // #1
            double[][] multiplicand = new double[][] {
                    {3, -1, 2},
                    {2,  0, 1},
                    {1,  2, 1}
            };
            double[][] multiplier = new double[][] {
                    {2, -1, 1},
                    {0, -2, 3},
                    {3,  0, 1}
            };
            System.out.println("#1\n" + toString(multiplyByMatrix(multiplicand, multiplier)));
            // #2
            multiplicand = new double[][] {
                    {1, 2, 0},
                    {-1, 3, 1},
                    {2, -2, 1}
            };
            multiplier = new double[][] {
                    {2},
                    {-1},
                    {1}
            };
            System.out.println("#2\n" + toString(multiplyByMatrix(multiplicand, multiplier)));
            // #3
            multiplicand = new double[][] {
                    {1, 2, -1},
                    {0,  1, 0}
            };
            multiplier = new double[][] {
                    {1, 1, 0, 0},
                    {0, 2, 1, 1},
                    {1, 1, 2, 2}
            };
            System.out.println("#3\n" + toString(multiplyByMatrix(multiplicand, multiplier)));
        }
    }
    

    输出:

    #1
          12.00      -1.00       2.00
           7.00      -2.00       3.00
           5.00      -5.00       8.00
    
    #2
           0.00
          -4.00
           7.00
    
    #3
           0.00       4.00       0.00       0.00
           0.00       2.00       1.00       1.00