普通的矩阵乘法
根据其定义,结果矩阵共有 $n^2$ 个元素,每个元素需要 $n$ 次乘法和加法,易得复杂度为 $O(n^3)$。
分治策略
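将规模为 $n \times n$ 的矩阵划分为 4 个规模为 $\frac{n}{2} \times \frac{n}{2}$ 的子矩阵,需要 8 次子矩阵乘法,递推式为 $T(n) = 8T(n/2) + O(n^2)$,解得 $T(n) = O(n^3)$,复杂度与普通算法相同。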
Strassen算法
依然采用分治的策略,但以更多的加减法操作为代价,将每层递归中的子矩阵乘法次数由原先的 8 次降为 7 次,使递归树更稀疏。递推式为 $T(n) = 7T(n/2) + O(n^2)$,复杂度达到 $O(n^{\log_2 7}) \approx O(n^{2.81})$。
代码
#include <bits/stdc++.h>
// rep(i, st, ed): loop with i running over the inclusive range st, st+1, ..., ed.
// All matrices in this file are indexed from 1, so loops are written rep(i, 1, size).
#define rep(i,st,ed) for(int i=st;i<=ed;++i)
using namespace std;
// NOTE(review): N is not referenced anywhere in the visible code; matrix sizes
// come from the input at runtime.
const int N=100;
// Element-wise sum of two 1-indexed size x size matrices:
// res[row][col] = A[row][col] + B[row][col] for 1 <= row, col <= size.
// Row/column 0 are never touched. Every output cell depends only on the input
// cells at the same position, so res may be the same matrix as A or B.
void add(int **A,int **B,int **res,int size)
{
    for (int row = 1; row <= size; ++row)
    {
        const int *left = A[row];
        const int *right = B[row];
        int *out = res[row];
        for (int col = 1; col <= size; ++col)
            out[col] = left[col] + right[col];
    }
}
// Element-wise difference of two 1-indexed size x size matrices:
// res[r][c] = A[r][c] - B[r][c] for 1 <= r, c <= size.
// Row/column 0 are never touched. Every output cell depends only on the input
// cells at the same position, so res may be the same matrix as A or B
// (strassen() relies on this, e.g. sub(C22, P3, C22, ...)).
void sub(int **A,int **B,int **res,int size)
{
    for (int r = 1; r <= size; ++r)
    {
        int *out = res[r];
        const int *minuend = A[r];
        const int *subtrahend = B[r];
        for (int c = 1; c <= size; ++c)
            out[c] = minuend[c] - subtrahend[c];
    }
}
namespace {

// Zero-initialised square scratch matrix addressed 1-based as m[i][j] with
// 1 <= i, j <= size (row/column 0 are allocated but unused, matching the
// indexing used throughout this file). The storage is owned by std::vector,
// so it is released automatically when the object goes out of scope, including
// when an allocation deeper in the recursion throws.
class StrassenScratch
{
public:
    explicit StrassenScratch(int size)
        : stride_(static_cast<std::size_t>(size) + 1),
          cells_(stride_ * stride_, 0),
          rows_(stride_, nullptr)
    {
        for (std::size_t i = 0; i < stride_; ++i)
            rows_[i] = cells_.data() + i * stride_;
    }

    // A copy would hold row pointers into the source object's storage.
    StrassenScratch(const StrassenScratch &) = delete;
    StrassenScratch &operator=(const StrassenScratch &) = delete;

    // int** view in the layout add(), sub() and strassen() take.
    int **get() { return rows_.data(); }

private:
    std::size_t stride_;
    std::vector<int> cells_;
    std::vector<int *> rows_;
};

} // namespace

// Strassen multiplication of two 1-indexed size x size matrices:
// res[1..size][1..size] = A * B.
//
// Each level splits both operands into 2x2 quadrants of side h = ceil(size/2)
// and builds the product from 7 recursive multiplications (P1..P7) instead of
// 8, paying with extra additions/subtractions. When size is odd, the operands
// are treated as zero-padded to 2h x 2h. The padding does not change the
// top-left size x size block of the product, so any size works, not just
// powers of two.
//
// A, B and res are only accessed at indices 1..size. res may be the same
// matrix as A or B, because both operands are copied into scratch quadrants
// before res is written. A non-positive size is a no-op.
// NOTE(review): all arithmetic is done in int and can overflow (undefined
// behaviour) for large entries.
void strassen(int **A,int **B,int **res,int size)
{
    if (size <= 0)
        return;  // empty product; also stops recursion that would never shrink
    if (size == 1)
    {
        res[1][1] = A[1][1] * B[1][1];
        return;
    }

    // Rounded up so odd sizes are covered completely by the four quadrants.
    const int h = (size + 1) / 2;

    StrassenScratch a11(h), a12(h), a21(h), a22(h);
    StrassenScratch b11(h), b12(h), b21(h), b22(h);
    StrassenScratch c11(h), c12(h), c21(h), c22(h);
    StrassenScratch p1(h), p2(h), p3(h), p4(h), p5(h), p6(h), p7(h);
    StrassenScratch lhs(h), rhs(h);

    int **A11 = a11.get(), **A12 = a12.get(), **A21 = a21.get(), **A22 = a22.get();
    int **B11 = b11.get(), **B12 = b12.get(), **B21 = b21.get(), **B22 = b22.get();
    int **C11 = c11.get(), **C12 = c12.get(), **C21 = c21.get(), **C22 = c22.get();
    int **P1 = p1.get(), **P2 = p2.get(), **P3 = p3.get(), **P4 = p4.get();
    int **P5 = p5.get(), **P6 = p6.get(), **P7 = p7.get();
    int **Ares = lhs.get(), **Bres = rhs.get();

    // Split A and B into quadrants. Quadrant cells that fall outside the
    // size x size input are simply not written and stay zero (the padding).
    for (int i = 1; i <= h; ++i)
    {
        const bool hasLower = i + h <= size;      // row i+h exists in the input
        for (int j = 1; j <= h; ++j)
        {
            const bool hasRight = j + h <= size;  // column j+h exists in the input
            A11[i][j] = A[i][j];
            B11[i][j] = B[i][j];
            if (hasRight)
            {
                A12[i][j] = A[i][j + h];
                B12[i][j] = B[i][j + h];
            }
            if (hasLower)
            {
                A21[i][j] = A[i + h][j];
                B21[i][j] = B[i + h][j];
            }
            if (hasLower && hasRight)
            {
                A22[i][j] = A[i + h][j + h];
                B22[i][j] = B[i + h][j + h];
            }
        }
    }

    // The seven Strassen products.
    sub(B12, B22, Bres, h);          // P1 = A11 (B12 - B22)
    strassen(A11, Bres, P1, h);
    add(A11, A12, Ares, h);          // P2 = (A11 + A12) B22
    strassen(Ares, B22, P2, h);
    add(A21, A22, Ares, h);          // P3 = (A21 + A22) B11
    strassen(Ares, B11, P3, h);
    sub(B21, B11, Bres, h);          // P4 = A22 (B21 - B11)
    strassen(A22, Bres, P4, h);
    add(A11, A22, Ares, h);          // P5 = (A11 + A22)(B11 + B22)
    add(B11, B22, Bres, h);
    strassen(Ares, Bres, P5, h);
    sub(A12, A22, Ares, h);          // P6 = (A12 - A22)(B21 + B22)
    add(B21, B22, Bres, h);
    strassen(Ares, Bres, P6, h);
    sub(A11, A21, Ares, h);          // P7 = (A11 - A21)(B11 + B12)
    add(B11, B12, Bres, h);
    strassen(Ares, Bres, P7, h);

    // Recombine the products into the quadrants of the result.
    add(P5, P4, Ares, h);            // C11 = P5 + P4 - P2 + P6
    sub(Ares, P2, Bres, h);
    add(Bres, P6, C11, h);
    add(P1, P2, C12, h);             // C12 = P1 + P2
    add(P3, P4, C21, h);             // C21 = P3 + P4
    add(P5, P1, C22, h);             // C22 = P5 + P1 - P3 - P7
    sub(C22, P3, C22, h);
    sub(C22, P7, C22, h);

    // Write back only the cells that belong to the size x size result;
    // the padded rows/columns are discarded.
    for (int i = 1; i <= h; ++i)
    {
        const bool hasLower = i + h <= size;
        for (int j = 1; j <= h; ++j)
        {
            const bool hasRight = j + h <= size;
            res[i][j] = C11[i][j];
            if (hasRight)
                res[i][j + h] = C12[i][j];
            if (hasLower)
                res[i + h][j] = C21[i][j];
            if (hasLower && hasRight)
                res[i + h][j + h] = C22[i][j];
        }
    }
}
// Reads n and then the two n x n matrices A and B (row-major, whitespace
// separated) from standard input, multiplies them with strassen(), and prints
// the product one row per line, each entry followed by a space.
int main()
{
    // Only C++ streams are used, so decoupling from C stdio is safe.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    // There is nothing to multiply if the size is missing, malformed or non-positive.
    if (!(std::cin >> n) || n <= 0)
        return 0;

    // 1-indexed, zero-initialised storage: row/column 0 exist but are unused,
    // matching the indexing strassen() uses. The vectors own the memory, so
    // nothing leaks, and entries that cannot be read stay 0 instead of
    // being indeterminate.
    const std::size_t dim = static_cast<std::size_t>(n) + 1;
    std::vector<int> aCells(dim * dim, 0);
    std::vector<int> bCells(dim * dim, 0);
    std::vector<int> cCells(dim * dim, 0);

    // Row-pointer view over contiguous storage, giving the int** layout
    // strassen() takes.
    auto rowsOf = [dim](std::vector<int> &cells) {
        std::vector<int *> rows(dim);
        for (std::size_t i = 0; i < dim; ++i)
            rows[i] = cells.data() + i * dim;
        return rows;
    };
    std::vector<int *> A_ = rowsOf(aCells);
    std::vector<int *> B_ = rowsOf(bCells);
    std::vector<int *> C_ = rowsOf(cCells);

    for (int i = 1; i <= n; ++i)
        for (int j = 1; j <= n; ++j)
            std::cin >> A_[i][j];
    for (int i = 1; i <= n; ++i)
        for (int j = 1; j <= n; ++j)
            std::cin >> B_[i][j];

    strassen(A_.data(), B_.data(), C_.data(), n);

    for (int i = 1; i <= n; ++i)
    {
        for (int j = 1; j <= n; ++j)
            std::cout << C_[i][j] << " ";
        std::cout << '\n';  // same bytes as endl, without a flush per row
    }
    return 0;
}