Skip to content

Latest commit

 

History

History
183 lines (157 loc) · 6.61 KB

File metadata and controls

183 lines (157 loc) · 6.61 KB

Strassen's Matrix Multiplication

Strassen's algorithm needs only 7 recursive multiplications instead of 8 when each matrix is broken down into quadrants, which improves the asymptotic run time (this can be proven with the recurrence relation T(n) = 7T(n/2) + O(n²)).

In its basic form, Strassen’s matrix multiplication can be performed only on square matrices where n is a power of 2. Both matrices are of order n × n.

Explanation

Divide X, Y and Z into four (n/2)×(n/2) matrices as represented below −

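$$Z = \begin{bmatrix} I & J \\ K & L \end{bmatrix}, \qquad X = \begin{bmatrix} A & B \\ C & D \end{bmatrix} \qquad \text{and} \qquad Y = \begin{bmatrix} E & F \\ G & H \end{bmatrix}$$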

Using Strassen’s Algorithm we compute the following −

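$$\begin{aligned}
M_1 &= (A + C) \times (E + F) \\
M_2 &= (B + D) \times (G + H) \\
M_3 &= (A - D) \times (E + H) \\
M_4 &= A \times (F - H) \\
M_5 &= (C + D) \times E \\
M_6 &= (A + B) \times H \\
M_7 &= D \times (G - E)
\end{aligned}$$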

Then,

$$\begin{aligned}
I &= M_2 + M_3 - M_6 - M_7 \\
J &= M_4 + M_6 \\
K &= M_5 + M_7 \\
L &= M_1 - M_3 - M_4 - M_5
\end{aligned}$$

Finally, we merge the four result quadrants I, J, K and L into the final array.

The final array is productResult:

Merge(I,productResult,0,0);
Merge(J,productResult,0,half);
Merge(K,productResult,half,0);
Merge(L,productResult,half,half);

Analysis:

The time complexity of Strassen’s matrix multiplication algorithm is $O(n^{\log_2 7}) \approx O(n^{2.81})$.

JAVA CODE

public class Strassens_Algorithm
{
	//classicMatrixMultiplication
	//uses 3 for loops

	public int[][] classicMatrixMultiplication(int[][] matrix1,int[][] matrix2)
	{
		int[][] matrix3 = new int[matrix1.length][matrix1.length];
		for(int i=0;i<matrix1.length;i++)             //deals with rows for matrix3 and 1
		{
			for(int j=0;j<matrix2[0].length;j++)      //deals with columns for matrix 3 and 2
			{
				for(int k=0;k<matrix1[0].length;k++)   //columns for matrix 1, rows for matrix 2
				{
					matrix3[i][j]+=(matrix1[i][k]*matrix2[k][j]); //
				}
			}
		}
		return matrix3;
	}

	//Strassen's Matrix Multiplication
	//uses formulas to calculate the product of two matrices
	//===============================================================================================================================
	public int[][] straussMultiplication(int[][] matrix1,int[][] matrix2)
	{
		int size=matrix1.length;        //reference for recurrsion
		int productResult[][]=new int[size][size];
		if(size==1)              //when the size is easy enough to compute "base case"
		{
			productResult[0][0]=matrix1[0][0]*matrix2[0][0];
		}
		else  //because matrix is 2^(x), if its not small enough, split into 4 parts for each matrix
		{
			int half=size/2;
			int[][] partA11=new int[half][half],partA12=new int[half][half];
			int[][] partA21=new int[half][half],partA22=new int[half][half];//first matrix split into 4 parts
			int[][] partB11=new int[half][half],partB12=new int[half][half];
			int[][] partB21=new int[half][half],partB22=new int[half][half];//second matrix split into 4 parts
			fill(matrix1,partA11,0,0);fill(matrix1,partA12,0,half);
			fill(matrix1,partA21,half,0);fill(matrix1,partA22,half,half);//filling quardrants for matrix1

			fill(matrix2,partB11,0,0);fill(matrix2,partB12,0,half);
			fill(matrix2,partB21,half,0);fill(matrix2,partB22,half,half); //filling quadrants for matrix2

			//following Strauss equations:

			int[][] P=straussMultiplication(straussAdd(partA11,partA22),straussAdd(partB11,partB22));
			int[][] Q=straussMultiplication(straussAdd(partA21,partA22),partB11);
			int[][] R=straussMultiplication(partA11,straussSubtract(partB12,partB22));
			int[][] S=straussMultiplication(partA22,straussSubtract(partB21,partB11));
			int[][] T=straussMultiplication(straussAdd(partA11,partA12),partB22);
			int[][] U=straussMultiplication(straussSubtract(partA21,partA11),straussAdd(partB11,partB12));
			int[][] V=straussMultiplication(straussSubtract(partA12,partA22),straussAdd(partB21,partB22));
			int[][] C11=straussAdd(straussSubtract(straussAdd(P,S),T),V);
			int[][] C12=straussAdd(R,T);
			int[][] C21=straussAdd(Q,S);
			int[][] C22=straussAdd(straussSubtract(straussAdd(P,R),Q),U);
			//merging results to final array
			straussMerge(C11,productResult,0,0);
            straussMerge(C12,productResult,0,half);
			straussMerge(C21,productResult,half,0);
            straussMerge(C22,productResult,half,half);
		}
			return productResult;
	}
	//===============================================================================================================================
	// Merge
	// gets two integer arrays, merges them back together based on position
	//===============================================================================================================================
	private void straussMerge(int[][] matrix1,int[][] result,int part1,int part2)
	{
		//position1 and position 2 will keep track of particular spot in split array
		for(int i=0,position1=part1;i<matrix1.length;i++,position1++)
		{
			for(int j=0,position2=part2;j<matrix1.length;j++,position2++)
			{
				result[position1][position2]=matrix1[i][j];

			}

		}
	}
	//===============================================================================================================================
	//Subtract
	//Standard matrix subtraction
	//===============================================================================================================================
	private int[][] straussSubtract(int[][] matrixA,int[][]matrixB)
	{
		int[][] difference=new int[matrixA.length][matrixA.length];
			for(int i=0;i<matrixA.length;i++)
			{
				for(int j=0;j<matrixA.length;j++)
				{
					difference[i][j]=matrixA[i][j]-matrixB[i][j];
				}
			}
			return difference;
	}
	//===============================================================================================================================
	//Add
	//standard matrix addition
	//===============================================================================================================================
	private int[][] straussAdd(int[][] matrixA,int[][]matrixB)//standard matrix adding
	{
		int[][] sum=new int[matrixA.length][matrixA.length];
			for(int i=0;i<matrixA.length;i++)
			{
				for(int j=0;j<matrixA.length;j++)
				{
					sum[i][j]=matrixA[i][j]+matrixB[i][j];
				}
			}
			return sum;


	}
	//===============================================================================================================================
	// fill
	// will populate the split arrays given specific position
	//===============================================================================================================================
	private void fill(int[][] matrix1,int[][] matrixFill,int pos1,int pos2)
		{
						//position 1 and position 2 will keep track of particular spot of split array
			for(int i=0,position1=pos1;i<matrixFill.length;i++,position1++)
			{
				for(int j=0,position2=pos2;j<matrixFill.length;j++,position2++)
				{
					matrixFill[i][j]=matrix1[position1][position2];
				}
			}
		}

}