您的位置:首页 > 编程语言 > Java开发

分享矩阵乘法单线程与多线程的Java实现与效率对比,请教Strassen算法

2013-09-04 21:50 323 查看
分享矩阵乘法单线程与多线程的Java实现与效率对比,请教Strassen算法

矩阵乘法的多线程实现:


/**


* @Title: MultiThreadMatrix.java


* @Package matrix


* @Description: 多线程计算矩阵乘法


* @author Aloong


* @date 2010-10-28 下午09:45:56


* @version V1.0


*/




package matrix;




import java.util.Date;






public class MultiThreadMatrix


{




static int[][] matrix1;


static int[][] matrix2;


static int[][] matrix3;


static int m,n,k;


static int index;


static int threadCount;


static long startTime;




public static void main(String[] args) throws InterruptedException


{


//矩阵a高度m=100宽度k=80,矩阵b高度k=80宽度n=50 ==> 矩阵c高度m=100宽度n=50


m = 1024;


n = 1024;


k = 1024;


matrix1 = new int[m][k];


matrix2 = new int[k]
;


matrix3 = new int[m]
;




//随机初始化矩阵a,b


fillRandom(matrix1);


fillRandom(matrix2);


startTime = new Date().getTime();




//输出a,b


// printMatrix(matrix1);


// printMatrix(matrix2);




//创建线程,数量 <= 4


for(int i=0; i<4; i++)


{


if(index < m)


{


Thread t = new Thread(new MyThread());


t.start();


}else


{


break;


}


}




//等待结束后输出


while(threadCount!=0)


{


Thread.sleep(20);


}




// printMatrix(matrix3);


long finishTime = new Date().getTime();


System.out.println("计算完成,用时"+(finishTime-startTime)+"毫秒");


}




static void printMatrix(int[][] x)


{


for (int i=0; i<x.length; i++)


{


for(int j=0; j<x[i].length; j++)


{


System.out.print(x[i][j]+" ");


}


System.out.println("");


}


System.out.println("");


}




static void fillRandom(int[][] x)


{


for (int i=0; i<x.length; i++)


{


for(int j=0; j<x[i].length; j++)


{


//每个元素设置为0到99的随机自然数


x[i][j] = (int) (Math.random() * 100);


}


}


}




synchronized static int getTask()


{


if(index < m)


{


return index++;


}


return -1;


}




}




class MyThread implements Runnable


{


int task;


@Override


public void run()


{


MultiThreadMatrix.threadCount++;


while( (task = MultiThreadMatrix.getTask()) != -1 )


{


System.out.println("进程: "+Thread.currentThread().getName()+"\t开始计算第 "+(task+1)+"行");


for(int i=0; i<MultiThreadMatrix.n; i++)


{


for(int j=0; j<MultiThreadMatrix.k; j++)


{


MultiThreadMatrix.matrix3[task][i] += MultiThreadMatrix.matrix1[task][j] * MultiThreadMatrix.matrix2[j][i];


}


}


}


MultiThreadMatrix.threadCount--;


}


}



单线程:


/**


* @Title: SingleThreadMatrix.java


* @Package matrix


* @Description: 单线程计算矩阵乘法


* @author Aloong


* @date 2010-10-28 下午11:33:18


* @version V1.0


*/




package matrix;




import java.util.Date;






public class SingleThreadMatrix


{


static int[][] matrix1;


static int[][] matrix2;


static int[][] matrix3;


static int m,n,k;


static long startTime;




public static void main(String[] args)


{


m = 1024;


n = 1024;


k = 1024;


matrix1 = new int[m][k];


matrix2 = new int[k]
;


matrix3 = new int[m]
;




fillRandom(matrix1);


fillRandom(matrix2);


startTime = new Date().getTime();




//输出a,b


// printMatrix(matrix1);


// printMatrix(matrix2);








for(int task=0; task<m; task++)


{


System.out.println("进程: "+Thread.currentThread().getName()+"\t开始计算第 "+(task+1)+"行");


for(int i=0; i<n; i++)


{


for(int j=0; j<k; j++)


{


matrix3[task][i] += matrix1[task][j] * matrix2[j][i];


}


}


}




// printMatrix(matrix3);


long finishTime = new Date().getTime();


System.out.println("计算完成,用时"+(finishTime-startTime)+"毫秒");


}




static void fillRandom(int[][] x)


{


for (int i=0; i<x.length; i++)


{


for(int j=0; j<x[i].length; j++)


{


//每个元素设置为0到99的随机自然数


x[i][j] = (int) (Math.random() * 100);


}


}


}


}



修改m,n,k的值可以修改相乘矩阵的阶数.

结果对比,计算1024阶矩阵的时候多线程用时约4.8秒,单线程用时16秒,

单线程占用内存21M,多线程占用16M.

本机是4核CPU,单线程的时候只有25%的CPU占用,使用4个子线程可以达到接近100%的CPU使用率.

另外请教一个问题,是矩阵乘法的Strassen算法

下面这个是来自网上的一段代码,在我自己的电脑上,只要超过12阶就会内存溢出

不解是什么原因,设置jvm的内存不管多大也会崩溃在12阶

请高手帮忙解答....




package matrix;




import java.io.*;


import java.util.*;




class Matrix //定义矩阵结构


{


public int[][] m = new int[32][32];


}




public class StrassenMatrix2


{


public int IfIsEven(int n)//判断输入矩阵阶数是否为2^k


{


int a = 0, temp = n;


while (temp % 2 == 0)


{


if (temp % 2 == 0)


temp /= 2;


else


a = 1;


}


if (temp == 1)


a = 0;


return a;


}




public void Divide(Matrix d, Matrix d11, Matrix d12, Matrix d21, Matrix d22, int n)//分解矩阵


{


int i, j;


for (i = 1; i <= n; i++)


for (j = 1; j <= n; j++)


{


d11.m[i][j] = d.m[i][j];


d12.m[i][j] = d.m[i][j + n];


d21.m[i][j] = d.m[i + n][j];


d22.m[i][j] = d.m[i + n][j + n];


}


}




public Matrix Merge(Matrix a11, Matrix a12, Matrix a21, Matrix a22, int n)//合并矩阵


{


int i, j;


Matrix a = new Matrix();


for (i = 1; i <= n; i++)


for (j = 1; j <= n; j++)


{


a.m[i][j] = a11.m[i][j];


a.m[i][j + n] = a12.m[i][j];


a.m[i + n][j] = a21.m[i][j];


a.m[i + n][j + n] = a22.m[i][j];


}


return a;


}




public Matrix TwoMatrixMultiply(Matrix x, Matrix y) //阶数为2的矩阵乘法


{


int m1, m2, m3, m4, m5, m6, m7;


Matrix z = new Matrix();




m1 = (y.m[1][2] - y.m[2][2]) * x.m[1][1];


m2 = y.m[2][2] * (x.m[1][1] + x.m[1][2]);


m3 = (x.m[2][1] + x.m[2][2]) * y.m[1][1];


m4 = x.m[2][2] * (y.m[2][1] - y.m[1][1]);


m5 = (x.m[1][1] + x.m[2][2]) * (y.m[1][1] + y.m[2][2]);


m6 = (x.m[1][2] - x.m[2][2]) * (y.m[2][1] + y.m[2][2]);


m7 = (x.m[1][1] - x.m[2][1]) * (y.m[1][1] + y.m[1][2]);


z.m[1][1] = m5 + m4 - m2 + m6;


z.m[1][2] = m1 + m2;


z.m[2][1] = m3 + m4;


z.m[2][2] = m5 + m1 - m3 - m7;


return z;


}




public Matrix MatrixPlus(Matrix f, Matrix g, int n) //矩阵加法


{


int i, j;


Matrix h = new Matrix();


for (i = 1; i <= n; i++)


for (j = 1; j <= n; j++)


h.m[i][j] = f.m[i][j] + g.m[i][j];


return h;


}




public Matrix MatrixMinus(Matrix f, Matrix g, int n) //矩阵减法方法


{


int i, j;


Matrix h = new Matrix();


for (i = 1; i <= n; i++)


for (j = 1; j <= n; j++)


h.m[i][j] = f.m[i][j] - g.m[i][j];


return h;


}




public Matrix MatrixMultiply(Matrix a, Matrix b, int n) //矩阵乘法方法


{


int k;


Matrix a11, a12, a21, a22;


a11 = new Matrix();


a12 = new Matrix();


a21 = new Matrix();


a22 = new Matrix();


Matrix b11, b12, b21, b22;


b11 = new Matrix();


b12 = new Matrix();


b21 = new Matrix();


b22 = new Matrix();


Matrix c11, c12, c21, c22, c;


c11 = new Matrix();


c12 = new Matrix();


c21 = new Matrix();


c22 = new Matrix();


c = new Matrix();


Matrix m1, m2, m3, m4, m5, m6, m7;


k = n;


if (k == 2)


{


c = TwoMatrixMultiply(a, b);


return c;


} else


{


k = n / 2;


Divide(a, a11, a12, a21, a22, k); //拆分A、B、C矩阵


Divide(b, b11, b12, b21, b22, k);


Divide(c, c11, c12, c21, c22, k);




m1 = MatrixMultiply(a11, MatrixMinus(b12, b22, k), k);


m2 = MatrixMultiply(MatrixPlus(a11, a12, k), b22, k);


m3 = MatrixMultiply(MatrixPlus(a21, a22, k), b11, k);


m4 = MatrixMultiply(a22, MatrixMinus(b21, b11, k), k);


m5 = MatrixMultiply(MatrixPlus(a11, a22, k),


MatrixPlus(b11, b22, k), k);


m6 = MatrixMultiply(MatrixMinus(a12, a22, k),


MatrixPlus(b21, b22, k), k);


m7 = MatrixMultiply(MatrixMinus(a11, a21, k),


MatrixPlus(b11, b12, k), k);


c11 = MatrixPlus(MatrixMinus(MatrixPlus(m5, m4, k), m2, k), m6, k);


c12 = MatrixPlus(m1, m2, k);


c21 = MatrixPlus(m3, m4, k);


c22 = MatrixMinus(MatrixMinus(MatrixPlus(m5, m1, k), m3, k), m7, k);




c = Merge(c11, c12, c21, c22, k); //合并C矩阵


return c;


}


}




public Matrix GetMatrix(Matrix X, int n)


{


int i, j;


X = new Matrix();


for (i = 1; i <= n; i++)


for (j = 1; j <= n; j++)


X.m[i][j] = (int) (Math.random() * 10);


for (i = 1; i <= n; i++)


{


for (j = 1; j <= n; j++)


System.out.print(X.m[i][j] + " ");


System.out.println();


}


return X;


}




public Matrix UsualMatrixMultiply(Matrix A, Matrix B, Matrix C, int n)


{


int i, j, t, k;


for (i = 1; i <= n; i++)


for (j = 1; j <= n; j++)


{


for (k = 1, t = 0; k <= n; k++)


t += A.m[i][k] * B.m[k][j];


C.m[i][j] = t;


}


return C;


}




public static void main(String[] args) throws IOException


{


StrassenMatrix2 instance = new StrassenMatrix2();


int i, j, n;


// Matrix A, B, C, D;


Matrix A, B, C;


A = new Matrix();


B = new Matrix();


C = new Matrix();


// D = new matrix();


Scanner in = new Scanner(System.in);


System.out.print("输入矩阵的阶数: ");


n = in.nextInt();


if (instance.IfIsEven(n) == 0)


{


System.out.println("矩阵A:");


A = instance.GetMatrix(A, n);


System.out.println("矩阵B:");


B = instance.GetMatrix(B, n);


if (n == 1)


C.m[1][1] = A.m[1][1] * B.m[1][1]; //矩阵阶数为1时的特殊处理


else


{


long startTime = new Date().getTime();


C = instance.MatrixMultiply(A, B, n);


long finishTime = new Date().getTime();


System.out.println("计算完成,用时"+(finishTime-startTime)+"毫秒");


}


System.out.println("Strassen矩阵C为:");


for (i = 1; i <= n; i++)


{


for (j = 1; j <= n; j++)


System.out.print(C.m[i][j] + " ");


System.out.println();


}


/* D = instance.UsualMatrixMultiply(A, B, D, n);


System.out.println("普通乘法矩阵D为:");


for (i = 1; i <= n; i++)


{


for (j = 1; j <= n; j++)


System.out.print(D.m[i][j] + " ");


System.out.println();


}*/


} else


System.out.println("输入的阶数不是2的N次方");


}


}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: 
相关文章推荐