java实现任意矩阵Strassen算法
本例输入为两个任意尺寸的矩阵m*n,n*m,输出为两个矩阵的乘积。计算任意尺寸矩阵相乘时,使用了Strassen算法。程序为自编,经过测试,请放心使用。基本算法是:
1.对于方阵(正方形矩阵),设边长为m,找到最大的l,使得l=2^k(k为整数)并且l≤m。若l=m,则整个矩阵直接采用Strassen算法;否则左上角边长为l的方形子矩阵采用Strassen算法,其余部分以及该子矩阵结果中遗漏的乘积项用蛮力法补齐。
2.对于非方阵,依照行列相应添加0使其成为方阵。
StrassenMethodTest.java
packagematrixalgorithm; importjava.util.Scanner; publicclassStrassenMethodTest{ privateStrassenMethodstrassenMultiply; StrassenMethodTest(){ strassenMultiply=newStrassenMethod(); }//endcons publicstaticvoidmain(String[]args){ Scannerinput=newScanner(System.in); System.out.println("Inputrowsizeofthefirstmatrix:"); intarow=input.nextInt(); System.out.println("Inputcolumnsizeofthefirstmatrix:"); intacol=input.nextInt(); System.out.println("Inputrowsizeofthesecondmatrix:"); intbrow=input.nextInt(); System.out.println("Inputcolumnsizeofthesecondmatrix:"); intbcol=input.nextInt(); double[][]A=newdouble[arow][acol]; double[][]B=newdouble[brow][bcol]; double[][]C=newdouble[arow][bcol]; System.out.println("InputdataformatrixA:"); /*Inallofthecodeslaterinthisproject, rmeansrowwhilecmeanscolumn. */ for(intr=0;r<arow;r++){ for(intc=0;c<acol;c++){ System.out.printf("DataofA[%d][%d]:",r,c); A[r][c]=input.nextDouble(); }//endinnerloop }//endloop System.out.println("InputdataformatrixB:"); for(intr=0;r<brow;r++){ for(intc=0;c<bcol;c++){ System.out.printf("DataofA[%d][%d]:",r,c); B[r][c]=input.nextDouble(); }//endinnerloop }//endloop StrassenMethodTestalgorithm=newStrassenMethodTest(); C=algorithm.multiplyRectMatrix(A,B,arow,acol,brow,bcol); //Displaythecalculationresult: System.out.println("ResultfrommatrixC:"); for(intr=0;r<arow;r++){ for(intc=0;c<bcol;c++){ System.out.printf("DataofC[%d][%d]:%f\n",r,c,C[r][c]); }//endinnerloop }//endoutterloop }//endmain //Dealwithmatricesthatarenotsquare: publicdouble[][]multiplyRectMatrix(double[][]A,double[][]B, intarow,intacol,intbrow,intbcol){ if(arow!=bcol)//Invalidmultiplicatio returnnewdouble[][]{{0}}; double[][]C=newdouble[arow][bcol]; if(arow<acol){ double[][]newA=newdouble[acol][acol]; double[][]newB=newdouble[brow][brow]; intn=acol; for(intr=0;r<acol;r++) for(intc=0;c<acol;c++) newA[r][c]=0.0; for(intr=0;r<brow;r++) for(intc=0;c<brow;c++) newB[r][c]=0.0; for(intr=0;r<arow;r++) for(intc=0;c<acol;c++) newA[r][c]=A[r][c]; 
for(intr=0;r<brow;r++) for(intc=0;c<bcol;c++) newB[r][c]=B[r][c]; double[][]C2=multiplySquareMatrix(newA,newB,n); for(intr=0;r<arow;r++) for(intc=0;c<bcol;c++) C[r][c]=C2[r][c]; }//endif elseif(arow==acol) C=multiplySquareMatrix(A,B,arow); else{ intn=arow; double[][]newA=newdouble[arow][arow]; double[][]newB=newdouble[bcol][bcol]; for(intr=0;r<arow;r++) for(intc=0;c<arow;c++) newA[r][c]=0.0; for(intr=0;r<bcol;r++) for(intc=0;c<bcol;c++) newB[r][c]=0.0; for(intr=0;r<arow;r++) for(intc=0;c<acol;c++) newA[r][c]=A[r][c]; for(intr=0;r<brow;r++) for(intc=0;c<bcol;c++) newB[r][c]=B[r][c]; double[][]C2=multiplySquareMatrix(newA,newB,n); for(intr=0;r<arow;r++) for(intc=0;c<bcol;c++) C[r][c]=C2[r][c]; }//endelse returnC; }//endmethod //Dealwithmatricesthataresquarematrices. publicdouble[][]multiplySquareMatrix(double[][]A2,double[][]B2,intn){ double[][]C2=newdouble[n][n]; for(intr=0;r<n;r++) for(intc=0;c<n;c++) C2[r][c]=0; if(n==1){ C2[0][0]=A2[0][0]*B2[0][0]; returnC2; }//endif intexp2k=2; while(exp2k<=(n/2)){ exp2k*=2; }//endloop if(exp2k==n){ C2=strassenMultiply.strassenMultiplyMatrix(A2,B2,n); returnC2; }//endelse //The"biggest"strassenmatrix: double[][][]A=newdouble[6][exp2k][exp2k]; double[][][]B=newdouble[6][exp2k][exp2k]; double[][][]C=newdouble[6][exp2k][exp2k]; for(intr=0;r<exp2k;r++){ for(intc=0;c<exp2k;c++){ A[0][r][c]=A2[r][c]; B[0][r][c]=B2[r][c]; }//endinnerloop }//endoutterloop C[0]=strassenMultiply.strassenMultiplyMatrix(A[0],B[0],exp2k); for(intr=0;r<exp2k;r++) for(intc=0;c<exp2k;c++) C2[r][c]=C[0][r][c]; intmiddle=exp2k/2; for(intr=0;r<middle;r++){ for(intc=exp2k;c<n;c++){ A[1][r][c-exp2k]=A2[r][c]; B[3][r][c-exp2k]=B2[r][c]; }//endinnerloop }//endoutterloop for(intr=exp2k;r<n;r++){ for(intc=0;c<middle;c++){ A[3][r-exp2k][c]=A2[r][c]; B[1][r-exp2k][c]=B2[r][c]; }//endinnerloop }//endoutterloop for(intr=middle;r<exp2k;r++){ for(intc=exp2k;c<n;c++){ A[2][r-middle][c-exp2k]=A2[r][c]; B[4][r-middle][c-exp2k]=B2[r][c]; }//endinnerloop }//endoutterloop 
for(intr=exp2k;r<n;r++){ for(intc=middle;c<n-exp2k+1;c++){ A[4][r-exp2k][c-middle]=A2[r][c]; B[2][r-exp2k][c-middle]=B2[r][c]; }//endinnerloop }//endoutterloop for(inti=1;i<=4;i++) C[i]=multiplyRectMatrix(A[i],B[i],middle,A[i].length,A[i].length,middle); /* Calculatethefinalresultsofgridsinthe"biggest2^ksquare, accordingtotherulesofmatricemultiplication. */ for(introw=0;row<exp2k;row++){ for(intcol=0;col<exp2k;col++){ for(intk=exp2k;k<n;k++){ C2[row][col]+=A2[row][k]*B2[k][col]; }//endloop }//endinnerloop }//endoutterloop //Usebruteforcetosolvetherest,willbeimprovedlater: for(intcol=exp2k;col<n;col++){ for(introw=0;row<n;row++){ for(intk=0;k<n;k++) C2[row][col]+=A2[row][k]*B2[k][row]; }//endinnerloop }//endoutterloop for(introw=exp2k;row<n;row++){ for(intcol=0;col<exp2k;col++){ for(intk=0;k<n;k++) C2[row][col]+=A2[row][k]*B2[k][row]; }//endinnerloop }//endoutterloop returnC2; }//endmethod }//endclass
StrassenMethod.java
packagematrixalgorithm; importjava.util.Scanner; publicclassStrassenMethod{ privatedouble[][][][]A=newdouble[2][2][][]; privatedouble[][][][]B=newdouble[2][2][][]; privatedouble[][][][]C=newdouble[2][2][][]; /*//Codesfortestingthisclass: publicstaticvoidmain(String[]args){ Scannerinput=newScanner(System.in); System.out.println("Inputsizeofthematrix:"); intn=input.nextInt(); double[][]A=newdouble[n][n]; double[][]B=newdouble[n][n]; double[][]C=newdouble[n][n]; System.out.println("InputdataformatrixA:"); for(intr=0;r<n;r++){ for(intc=0;c<n;c++){ System.out.printf("DataofA[%d][%d]:",r,c); A[r][c]=input.nextDouble(); }//endinnerloop }//endloop System.out.println("InputdataformatrixB:"); for(intr=0;r<n;r++){ for(intc=0;c<n;c++){ System.out.printf("DataofA[%d][%d]:",r,c); B[r][c]=input.nextDouble(); }//endinnerloop }//endloop StrassenMethodalgorithm=newStrassenMethod(); C=algorithm.strassenMultiplyMatrix(A,B,n); System.out.println("ResultfrommatrixC:"); for(intr=0;r<n;r++){ for(intc=0;c<n;c++){ System.out.printf("DataofC[%d][%d]:%f\n",r,c,C[r][c]); }//endinnerloop }//endoutterloop }//endmain*/ publicdouble[][]strassenMultiplyMatrix(double[][]A2,doubleB2[][],intn){ double[][]C2=newdouble[n][n]; //Initializethematrix: for(introwIndex=0;rowIndex<n;rowIndex++) for(intcolIndex=0;colIndex<n;colIndex++) C2[rowIndex][colIndex]=0.0; if(n==1) C2[0][0]=A2[0][0]*B2[0][0]; //"Slicematricesinto2*2parts: else{ double[][][][]A=newdouble[2][2][n/2][n/2]; double[][][][]B=newdouble[2][2][n/2][n/2]; double[][][][]C=newdouble[2][2][n/2][n/2]; for(intr=0;r<n/2;r++){ for(intc=0;c<n/2;c++){ A[0][0][r][c]=A2[r][c]; A[0][1][r][c]=A2[r][n/2+c]; A[1][0][r][c]=A2[n/2+r][c]; A[1][1][r][c]=A2[n/2+r][n/2+c]; B[0][0][r][c]=B2[r][c]; B[0][1][r][c]=B2[r][n/2+c]; B[1][0][r][c]=B2[n/2+r][c]; B[1][1][r][c]=B2[n/2+r][n/2+c]; }//endloop }//endloop n=n/2; double[][][]S=newdouble[10][n][n]; S[0]=minusMatrix(B[0][1],B[1][1],n); S[1]=addMatrix(A[0][0],A[0][1],n); S[2]=addMatrix(A[1][0],A[1][1],n); 
S[3]=minusMatrix(B[1][0],B[0][0],n); S[4]=addMatrix(A[0][0],A[1][1],n); S[5]=addMatrix(B[0][0],B[1][1],n); S[6]=minusMatrix(A[0][1],A[1][1],n); S[7]=addMatrix(B[1][0],B[1][1],n); S[8]=minusMatrix(A[0][0],A[1][0],n); S[9]=addMatrix(B[0][0],B[0][1],n); double[][][]P=newdouble[7][n][n]; P[0]=strassenMultiplyMatrix(A[0][0],S[0],n); P[1]=strassenMultiplyMatrix(S[1],B[1][1],n); P[2]=strassenMultiplyMatrix(S[2],B[0][0],n); P[3]=strassenMultiplyMatrix(A[1][1],S[3],n); P[4]=strassenMultiplyMatrix(S[4],S[5],n); P[5]=strassenMultiplyMatrix(S[6],S[7],n); P[6]=strassenMultiplyMatrix(S[8],S[9],n); C[0][0]=addMatrix(minusMatrix(addMatrix(P[4],P[3],n),P[1],n),P[5],n); C[0][1]=addMatrix(P[0],P[1],n); C[1][0]=addMatrix(P[2],P[3],n); C[1][1]=minusMatrix(minusMatrix(addMatrix(P[4],P[0],n),P[2],n),P[6],n); n*=2; for(intr=0;r<n/2;r++){ for(intc=0;c<n/2;c++){ C2[r][c]=C[0][0][r][c]; C2[r][n/2+c]=C[0][1][r][c]; C2[n/2+r][c]=C[1][0][r][c]; C2[n/2+r][n/2+c]=C[1][1][r][c]; }//endinnerloop }//endoutterloop }//endelse returnC2; }//endmethod //Addtwomatricesaccordingtomatrixaddition. privatedouble[][]addMatrix(double[][]A,double[][]B,intn){ doubleC[][]=newdouble[n][n]; for(intr=0;r<n;r++) for(intc=0;c<n;c++) C[r][c]=A[r][c]+B[r][c]; returnC; }//endmethod //Substracttwomatricesaccordingtomatrixaddition. privatedouble[][]minusMatrix(double[][]A,double[][]B,intn){ doubleC[][]=newdouble[n][n]; for(intr=0;r<n;r++) for(intc=0;c<n;c++) C[r][c]=A[r][c]-B[r][c]; returnC; }//endmethod }//endclass
希望本文所述对大家学习Java程序设计有所帮助。