矩阵乘法 Strassen算法

435 阅读3分钟

普通的矩阵乘法

根据矩阵乘法的定义直接计算，易得其时间复杂度为 $O(n^3)$。

分治策略

将规模为 $n\times n$ 的矩阵划分为 4 个规模为 $\frac{n}{2}\times\frac{n}{2}$ 的子矩阵，每层递归需要计算 8 次子矩阵乘法，递推式为 $T(n)=8T(n/2)+\Theta(n^2)$，解得 $T(n)=\Theta(n^3)$，因此复杂度与普通算法相同。

Strassen算法

依然采用分治策略，但以更多的矩阵加减法为代价，将每层递归中子矩阵乘法的次数由原先的 8 次降为 7 次，使递归树更稀疏。递推式变为 $T(n)=7T(n/2)+\Theta(n^2)$，复杂度达到 $O(n^{\log_2 7})\approx O(n^{2.81})$。

代码

#include <bits/stdc++.h>
// rep(i, st, ed): loop with a fresh int i running from st up to ed inclusive.
// The bounds are parenthesised so an expression argument such as `n & mask`
// is compared as a whole (unparenthesised, `i<=n&mask` would parse as
// `(i<=n)&mask`). Note that ed is re-evaluated on every iteration.
#define rep(i,st,ed) for(int i=(st);i<=(ed);++i)
using namespace std;
// NOTE(review): N is not referenced anywhere in the visible code.
const int N=100;
// Element-wise matrix sum: res[r][c] = A[r][c] + B[r][c] for r, c in [1, size]
// (1-based indexing; row/column 0 are never touched). Each entry is read
// before the same entry of res is written, so res may be the same matrix as
// A or B. A size <= 0 leaves res unchanged.
void add(int **A,int **B,int **res,int size)
{
	for (int row = 1; row <= size; ++row)
	{
		const int *lhs = A[row];
		const int *rhs = B[row];
		int *out = res[row];
		for (int col = 1; col <= size; ++col)
			out[col] = lhs[col] + rhs[col];
	}
}
// Element-wise matrix difference: res[r][c] = A[r][c] - B[r][c] for r, c in
// [1, size] (1-based indexing; row/column 0 are never touched). Each entry is
// read before the same entry of res is written, so res may be the same matrix
// as A or B. A size <= 0 leaves res unchanged.
void sub(int **A,int **B,int **res,int size)
{
	for (int row = 1; row <= size; ++row)
	{
		const int *minuend = A[row];
		const int *subtrahend = B[row];
		int *out = res[row];
		for (int col = 1; col <= size; ++col)
			out[col] = minuend[col] - subtrahend[col];
	}
}
namespace
{
// Owns a zero-initialised (size + 1) x (size + 1) block of ints and exposes it
// as an int** with 1-based indexing, the layout add/sub/strassen operate on
// (row 0 and column 0 are allocated but never used). The storage is released
// automatically when the object goes out of scope, including on exceptions.
class StrassenMatrix
{
public:
	explicit StrassenMatrix(int size)
		: cells_(static_cast<std::size_t>(size + 1) * static_cast<std::size_t>(size + 1), 0),
		  rows_(static_cast<std::size_t>(size + 1))
	{
		const std::size_t stride = static_cast<std::size_t>(size + 1);
		for (std::size_t i = 0; i < rows_.size(); ++i)
			rows_[i] = cells_.data() + i * stride;
	}

	// rows_ points into cells_, so a member-wise copy would alias the source.
	StrassenMatrix(const StrassenMatrix &) = delete;
	StrassenMatrix &operator=(const StrassenMatrix &) = delete;

	int **get() { return rows_.data(); }
	int *operator[](int i) { return rows_[static_cast<std::size_t>(i)]; }

private:
	std::vector<int> cells_;
	std::vector<int *> rows_;
};
} // namespace

// Computes res = A * B for size x size matrices using Strassen's algorithm.
//
// A, B and res use 1-based indexing: only entries [1..size][1..size] are read
// or written. res may alias A or B, because the operands are fully copied into
// temporaries before res is written. Any size >= 1 is supported: an odd size
// is zero-padded by one row and column so that halving never drops data.
// size <= 0 is a no-op. Entries are multiplied and summed as int, so large
// values can overflow.
void strassen(int **A,int **B,int **res,int size)
{
	if (size <= 0)
		return;
	if (size == 1)
	{
		res[1][1] = A[1][1] * B[1][1];
		return;
	}
	if (size % 2 != 0)
	{
		// size / 2 would silently drop the last row and column. Embed the
		// operands in zero-padded (size + 1)-square matrices instead:
		// [A 0; 0 0] * [B 0; 0 0] = [A*B 0; 0 0].
		const int padded = size + 1;
		StrassenMatrix PA(padded), PB(padded), PC(padded);
		for (int i = 1; i <= size; ++i)
			for (int j = 1; j <= size; ++j)
			{
				PA[i][j] = A[i][j];
				PB[i][j] = B[i][j];
			}
		strassen(PA.get(), PB.get(), PC.get(), padded);
		for (int i = 1; i <= size; ++i)
			for (int j = 1; j <= size; ++j)
				res[i][j] = PC[i][j];
		return;
	}

	const int h = size / 2;

	// Copy out the four quadrants of each operand. Every temporary has
	// h + 1 rows/columns so that 1-based index h is in bounds.
	StrassenMatrix A11(h), A12(h), A21(h), A22(h);
	StrassenMatrix B11(h), B12(h), B21(h), B22(h);
	for (int i = 1; i <= h; ++i)
		for (int j = 1; j <= h; ++j)
		{
			A11[i][j] = A[i][j];
			A12[i][j] = A[i][j + h];
			A21[i][j] = A[i + h][j];
			A22[i][j] = A[i + h][j + h];

			B11[i][j] = B[i][j];
			B12[i][j] = B[i][j + h];
			B21[i][j] = B[i + h][j];
			B22[i][j] = B[i + h][j + h];
		}

	// The seven half-size products. SA / SB are scratch space for the sums
	// and differences of A and B quadrants fed into each product.
	StrassenMatrix P1(h), P2(h), P3(h), P4(h), P5(h), P6(h), P7(h);
	StrassenMatrix SA(h), SB(h);

	sub(B12.get(), B22.get(), SB.get(), h);
	strassen(A11.get(), SB.get(), P1.get(), h);   // P1 = A11 (B12 - B22)

	add(A11.get(), A12.get(), SA.get(), h);
	strassen(SA.get(), B22.get(), P2.get(), h);   // P2 = (A11 + A12) B22

	add(A21.get(), A22.get(), SA.get(), h);
	strassen(SA.get(), B11.get(), P3.get(), h);   // P3 = (A21 + A22) B11

	sub(B21.get(), B11.get(), SB.get(), h);
	strassen(A22.get(), SB.get(), P4.get(), h);   // P4 = A22 (B21 - B11)

	add(A11.get(), A22.get(), SA.get(), h);
	add(B11.get(), B22.get(), SB.get(), h);
	strassen(SA.get(), SB.get(), P5.get(), h);    // P5 = (A11 + A22)(B11 + B22)

	sub(A12.get(), A22.get(), SA.get(), h);
	add(B21.get(), B22.get(), SB.get(), h);
	strassen(SA.get(), SB.get(), P6.get(), h);    // P6 = (A12 - A22)(B21 + B22)

	sub(A11.get(), A21.get(), SA.get(), h);
	add(B11.get(), B12.get(), SB.get(), h);
	strassen(SA.get(), SB.get(), P7.get(), h);    // P7 = (A11 - A21)(B11 + B12)

	// Recombine the quadrants of the result directly into res. A and B were
	// fully copied above, so this is safe even if res aliases one of them.
	for (int i = 1; i <= h; ++i)
		for (int j = 1; j <= h; ++j)
		{
			res[i][j]         = P5[i][j] + P4[i][j] - P2[i][j] + P6[i][j]; // C11
			res[i][j + h]     = P1[i][j] + P2[i][j];                       // C12
			res[i + h][j]     = P3[i][j] + P4[i][j];                       // C21
			res[i + h][j + h] = P5[i][j] + P1[i][j] - P3[i][j] - P7[i][j]; // C22
		}
}
// Reads n followed by two n x n integer matrices A and B (row-major,
// whitespace-separated) from standard input and prints A * B, computed with
// strassen(), one row per line with a space after every entry.
// On a negative n or truncated/non-numeric input, prints a message to
// standard error and exits with status 1. n == 0 prints nothing.
int main()
{
	std::ios_base::sync_with_stdio(false);
	std::cin.tie(nullptr);

	int n = 0;
	if (!(std::cin >> n) || n < 0)
	{
		std::cerr << "error: expected a non-negative matrix size\n";
		return 1;
	}

	// strassen() works on int** with 1-based indexing, so each matrix gets
	// (n + 1) x (n + 1) vector-owned storage plus a table of row pointers.
	// The vectors free everything on return, so nothing leaks.
	const std::size_t dim = static_cast<std::size_t>(n) + 1;
	std::vector<std::vector<int>> a(dim, std::vector<int>(dim, 0));
	std::vector<std::vector<int>> b(dim, std::vector<int>(dim, 0));
	std::vector<std::vector<int>> c(dim, std::vector<int>(dim, 0));
	std::vector<int *> A_(dim), B_(dim), C_(dim);
	for (std::size_t i = 0; i < dim; ++i)
	{
		A_[i] = a[i].data();
		B_[i] = b[i].data();
		C_[i] = c[i].data();
	}

	// Read all of A, then all of B; refuse to multiply partially read data.
	int **inputs[2] = {A_.data(), B_.data()};
	for (int m = 0; m < 2; ++m)
		for (int i = 1; i <= n; ++i)
			for (int j = 1; j <= n; ++j)
				if (!(std::cin >> inputs[m][i][j]))
				{
					std::cerr << "error: not enough matrix entries\n";
					return 1;
				}

	strassen(A_.data(), B_.data(), C_.data(), n);

	for (int i = 1; i <= n; ++i)
	{
		for (int j = 1; j <= n; ++j)
			std::cout << C_[i][j] << ' ';
		std::cout << '\n';
	}
	return 0;
}