使用java写的矩阵乘法实例(Strassen算法)
Strassen算法于1969年由德国数学家Strassen提出,该方法引入七个中间变量,每个中间变量都只需要进行一次乘法运算。而朴素算法却需要进行8次乘法运算。
原理
Strassen算法的原理如下所示,使用sympy验证Strassen算法的正确性
importsympyass
A=s.Symbol("A")
B=s.Symbol("B")
C=s.Symbol("C")
D=s.Symbol("D")
E=s.Symbol("E")
F=s.Symbol("F")
G=s.Symbol("G")
H=s.Symbol("H")
p1=A*(F-H)
p2=(A+B)*H
p3=(C+D)*E
p4=D*(G-E)
p5=(A+D)*(E+H)
p6=(B-D)*(G+H)
p7=(A-C)*(E+F)
print(A*E+B*G,(p5+p4-p2+p6).simplify())
print(A*F+B*H,(p1+p2).simplify())
print(C*E+D*G,(p3+p4).simplify())
print(C*F+D*H,(p1+p5-p3-p7).simplify())
复杂度分析
$$f(N)=7\timesf(\frac{N}{2})=7^2\timesf(\frac{N}{4})=...=7^k\timesf(\frac{N}{2^k})$$
最终复杂度为$7^{log_2N}=N^{log_27}$
java矩阵乘法(Strassen算法)
代码如下,可以看看数据结构的定义,时间换空间。
publicclassMatrix{
privatefinalMatrix[]_matrixArray;
privatefinalintn;
privateintelement;
publicMatrix(intn){
this.n=n;
if(n!=1){
this._matrixArray=newMatrix[4];
for(inti=0;i<4;i++){
this._matrixArray[i]=newMatrix(n/2);
}
}else{
this._matrixArray=null;
}
}
privateMatrix(intn,booleanneedInit){
this.n=n;
if(n!=1){
this._matrixArray=newMatrix[4];
}else{
this._matrixArray=null;
}
}
publicvoidset(inti,intj,inta){
if(n==1){
element=a;
}else{
intsize=n/2;
this._matrixArray[(i/size)*2+(j/size)].set(i%size,j%size,a);
}
}
publicMatrixmulti(Matrixm){
Matrixresult=null;
if(n==1){
result=newMatrix(1);
result.set(0,0,(element*m.element));
}else{
result=newMatrix(n,false);
result._matrixArray[0]=P5(m).add(P4(m)).minus(P2(m)).add(P6(m));
result._matrixArray[1]=P1(m).add(P2(m));
result._matrixArray[2]=P3(m).add(P4(m));
result._matrixArray[3]=P5(m).add(P1(m)).minus(P3(m)).minus(P7(m));
}
returnresult;
}
publicMatrixadd(Matrixm){
Matrixresult=null;
if(n==1){
result=newMatrix(1);
result.set(0,0,(element+m.element));
}else{
result=newMatrix(n,false);
result._matrixArray[0]=this._matrixArray[0].add(m._matrixArray[0]);
result._matrixArray[1]=this._matrixArray[1].add(m._matrixArray[1]);
result._matrixArray[2]=this._matrixArray[2].add(m._matrixArray[2]);
result._matrixArray[3]=this._matrixArray[3].add(m._matrixArray[3]);;
}
returnresult;
}
publicMatrixminus(Matrixm){
Matrixresult=null;
if(n==1){
result=newMatrix(1);
result.set(0,0,(element-m.element));
}else{
result=newMatrix(n,false);
result._matrixArray[0]=this._matrixArray[0].minus(m._matrixArray[0]);
result._matrixArray[1]=this._matrixArray[1].minus(m._matrixArray[1]);
result._matrixArray[2]=this._matrixArray[2].minus(m._matrixArray[2]);
result._matrixArray[3]=this._matrixArray[3].minus(m._matrixArray[3]);;
}
returnresult;
}
protectedMatrixP1(Matrixm){
return_matrixArray[0].multi(m._matrixArray[1]).minus(_matrixArray[0].multi(m._matrixArray[3]));
}
protectedMatrixP2(Matrixm){
return_matrixArray[0].multi(m._matrixArray[3]).add(_matrixArray[1].multi(m._matrixArray[3]));
}
protectedMatrixP3(Matrixm){
return_matrixArray[2].multi(m._matrixArray[0]).add(_matrixArray[3].multi(m._matrixArray[0]));
}
protectedMatrixP4(Matrixm){
return_matrixArray[3].multi(m._matrixArray[2]).minus(_matrixArray[3].multi(m._matrixArray[0]));
}
protectedMatrixP5(Matrixm){
return(_matrixArray[0].add(_matrixArray[3])).multi(m._matrixArray[0].add(m._matrixArray[3]));
}
protectedMatrixP6(Matrixm){
return(_matrixArray[1].minus(_matrixArray[3])).multi(m._matrixArray[2].add(m._matrixArray[3]));
}
protectedMatrixP7(Matrixm){
return(_matrixArray[0].minus(_matrixArray[2])).multi(m._matrixArray[0].add(m._matrixArray[1]));
}
publicintget(inti,intj){
if(n==1){
returnelement;
}else{
intsize=n/2;
returnthis._matrixArray[(i/size)*2+(j/size)].get(i%size,j%size);
}
}
publicvoiddisplay(){
for(inti=0;i
总结
到此这篇关于使用java写的矩阵乘法的文章就介绍到这了,更多相关java矩阵乘法(Strassen算法)内容请搜索毛票票以前的文章或继续浏览下面的相关文章希望大家以后多多支持毛票票!