
MPI Study (2) -- Implementing General Matrix Multiplication with MPI

2015-08-10 14:05
Approach:

There are two matrices, A and B. Process 0 reads in A and B and transposes B (the transpose routine itself is nothing special; the point of transposing is that the later dot products can scan both operands contiguously, which improves the cache hit rate, though it is not clear whether that pays for the time spent transposing). A is then split into row blocks and distributed to the processes with MPI_Scatter, so the matrix size must be divisible by the number of processes. Each process computes on its private block, and finally process 0 collects the partial results.
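To see what the transpose buys, look at the serial kernel in isolation. A minimal sketch (matmul_bt is an illustrative name, not part of the program below): without the transpose the inner loop would read b[k*p + j], jumping p doubles per iteration and touching a new cache line almost every time; with B stored transposed, both operands are scanned contiguously.

#include <stddef.h>

/* c = a * b, where a is n x m and bt holds b (m x p) transposed,
 * i.e. bt[j*m + k] == b[k*p + j].  The inner loop reads a and bt
 * sequentially, so each fetched cache line is used in full. */
void matmul_bt(const double *a, const double *bt, double *c,
               size_t n, size_t m, size_t p)
{
    for (size_t i = 0; i < n; i++)
        for (size_t j = 0; j < p; j++) {
            double sum = 0.0;
            for (size_t k = 0; k < m; k++)
                sum += a[i*m + k] * bt[j*m + k];
            c[i*p + j] = sum;
        }
}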

Code: github_mpi_multiply

/*
 * =====================================================================================
 *
 *       Filename:  multiply.c
 *
 *    Description:  matrix multiplication
 *
 *        Version:  1.0
 *        Created:  2015-08-08 14:54:19
 *       Revision:  none
 *       Compiler:  gcc
 *
 *         Author:  shiyan (), shiyan233@hotmail.com
 *   Organization:
 *
 * =====================================================================================
 */

#include <stdlib.h>
#include <stdio.h>
#include <time.h>
#include "mpi.h"

/*-----------------------------------------------------------------------------
 *  An n*m matrix has n rows and m columns.
 *-----------------------------------------------------------------------------*/
#define N 4
#define M 4

void Transpose(double *b, double *tmp, int n, int row, int col);
/*
 * ===  FUNCTION  ======================================================================
 *         Name:  Read_vec
 *  Description:  rank 0 generates random A and B, broadcasts B^T, and scatters A
 * =====================================================================================
 */
void
Read_vec (double *local_a,
          double *tmp,
          int local_n,
          int n,
          int myrank,
          MPI_Comm comm)
{
    double *a = NULL;
    double *b = NULL;
    if (myrank == 0)
    {
        srand(time(NULL));
        a = malloc( n*sizeof(double) );
        int i = 0;
        printf("a is\n");
        for (i=0; i<n; i++)
        {
            a[i] = rand()%10;
            if ((i+1)%M == 0)
                printf("%lf\n", a[i]);
            else printf("%lf, ", a[i]);
        }
        printf("\nb is \n");
        b = malloc( n*sizeof(double) );
        for (i=0; i<n; i++)
        {
            b[i] = rand()%10;
            if ((i+1)%N == 0)
                printf("%lf\n", b[i]);
            else printf("%lf, ", b[i]);
        }
        putchar('\n');
        Transpose(b, tmp, n, M, N);   /* store B^T in tmp */
        MPI_Bcast(tmp, M*N, MPI_DOUBLE, 0, comm);
        MPI_Scatter(a, local_n, MPI_DOUBLE, local_a, local_n, MPI_DOUBLE, 0, comm);
        free(a);
        free(b);
    }else{
        /* collectives must be called by every rank; the send buffer is
         * only significant at the root, so NULL is fine here */
        MPI_Bcast(tmp, M*N, MPI_DOUBLE, 0, comm);
        MPI_Scatter(NULL, local_n, MPI_DOUBLE, local_a, local_n, MPI_DOUBLE, 0, comm);
    }
}        /* -----  end of function Read_vec  ----- */

/*
 * ===  FUNCTION  ======================================================================
 *         Name:  Gather_vec
 *  Description:  rank 0 collects the partial results and prints the N*N product
 * =====================================================================================
 */
void
Gather_vec (double *local_c,
            int local_n,
            int n,
            int myrank,
            MPI_Comm comm)
{
    double *t = NULL;
    int i = 0;
    if (myrank == 0){
        t = malloc( n*sizeof(double) );
        MPI_Gather(local_c, local_n, MPI_DOUBLE, t, local_n, MPI_DOUBLE, 0, comm);
        printf("\nresult:\n");
        for (i=0; i<n; i++) {
            if ((i+1)%N == 0)
                printf("%lf\n", t[i]);
            else printf("%lf ", t[i]);
        }
        putchar('\n');
        free(t);
    }else{
        /* the receive arguments are ignored on non-root ranks */
        MPI_Gather(local_c, local_n, MPI_DOUBLE, NULL, 0, MPI_DOUBLE, 0, comm);
    }
}        /* -----  end of function Gather_vec  ----- */
/*
 * ===  FUNCTION  ======================================================================
 *         Name:  Transpose
 *  Description:  transpose using O(n) extra space
 * =====================================================================================
 */
void Transpose(double *b, double *tmp, int n, int row, int col)
{
    /* b has row rows and col columns, stored row-major;
     * tmp receives the col-by-row transpose */
    int k = 0;
    for (k=0; k<n; k++)
    {
        tmp[k] = b[ (k%row)*col + k/row ];
    }
}        /* -----  end of function Transpose  ----- */

/*
 * ===  FUNCTION  ======================================================================
 *         Name:  Multiply
 *  Description:  scatter A by row blocks, multiply against B^T, gather C
 * =====================================================================================
 */
void
Multiply ()
{
    MPI_Init(NULL, NULL);

    int myrank;
    int size;
    MPI_Comm_rank(MPI_COMM_WORLD, &myrank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    /* each process owns N/size rows of A, so N must be divisible by size */
    int local_rows = N/size;
    int local_n = local_rows*M;                       /* elements of A per process */
    double *local_a = malloc(local_n*sizeof(double));
    double *local_b = malloc(M*N*sizeof(double));     /* full B^T on every rank */
    double *local_c = malloc(local_rows*N*sizeof(double));
    Read_vec(local_a, local_b, local_n, M*N, myrank, MPI_COMM_WORLD);
    int r = 0;
    int i = 0;
    int j = 0;
    double temp = 0;
    /* C[r][i] = dot(row r of A, row i of B^T) = dot(row r of A, column i of B) */
    for (r=0; r<local_rows; r++) {
        for (i=0; i<N; i++) {
            temp = 0;
            for (j=0; j<M; j++)
                temp += local_a[r*M+j]*local_b[i*M+j];
            local_c[r*N+i] = temp;
        }
    }
    Gather_vec(local_c, local_rows*N, N*N, myrank, MPI_COMM_WORLD);
    free(local_a);
    free(local_b);
    free(local_c);
    MPI_Finalize();
}        /* -----  end of function Multiply  ----- */
/*
 * ===  FUNCTION  ======================================================================
 *         Name:  main
 *  Description:  program entry point
 * =====================================================================================
 */
int
main ( int argc, char *argv[] )
{
    Multiply();
    return EXIT_SUCCESS;
}                /* ----------  end of function main  ---------- */
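For reference, with a standard MPI toolchain (assuming mpicc and mpirun are available; the post itself shows no build commands) the program can be compiled and run as:

mpicc multiply.c -o multiply
mpirun -np 4 ./multiply    # requires N to be divisible by the process count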


Parameters and purpose of the MPI functions used:

1. MPI_Bcast: MPI's broadcast; it sends data from a root process to every process in the communicator.

MPI_Bcast(void *buffer, int count, MPI_Datatype datatype, int root, MPI_Comm comm);

The first argument is the buffer (it holds the data to send on the root and receives the data on every other rank), the second is the element count, the third is the MPI datatype, the fourth is the root (source) rank, and the fifth is the communicator.
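A minimal, self-contained usage sketch (a toy program, not from the repository):

#include <mpi.h>
#include <stdio.h>

/* MPI_Bcast demo: rank 0 fills the buffer; after the call every
 * rank holds the same contents. */
int main(int argc, char *argv[])
{
    MPI_Init(&argc, &argv);
    int rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    double buf[4] = {0};
    if (rank == 0) {                 /* only the root has real data... */
        for (int i = 0; i < 4; i++)
            buf[i] = i + 1;
    }
    MPI_Bcast(buf, 4, MPI_DOUBLE, 0, MPI_COMM_WORLD);
    printf("rank %d: buf[3] = %lf\n", rank, buf[3]);   /* ...now every rank does */

    MPI_Finalize();
    return 0;
}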

2. MPI_Scatter(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)

sendbuf is the data to distribute; sendcount is the number of elements sent to each process, not the total, so the root's send buffer must hold sendcount times the number of processes.

MPI_Scatter is executed on every process because it is a collective operation: all ranks in the communicator must call it. This also lets the library move the data along a tree, finishing in O(log p) communication steps instead of the O(p) steps a naive loop of point-to-point sends from the root would need; the same applies to MPI_Bcast.
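A toy usage sketch (CHUNK and the buffer names are illustrative, not from the repository):

#include <mpi.h>
#include <stdio.h>
#include <stdlib.h>

#define CHUNK 2   /* elements delivered to each process */

/* MPI_Scatter demo: the root splits size*CHUNK doubles into
 * per-process chunks of CHUNK doubles each. */
int main(int argc, char *argv[])
{
    MPI_Init(&argc, &argv);
    int rank, size;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    double *all = NULL;
    if (rank == 0) {                      /* send buffer matters only at the root */
        all = malloc(size * CHUNK * sizeof(double));
        for (int i = 0; i < size * CHUNK; i++)
            all[i] = i;
    }

    double mine[CHUNK];
    MPI_Scatter(all, CHUNK, MPI_DOUBLE, mine, CHUNK, MPI_DOUBLE, 0, MPI_COMM_WORLD);
    printf("rank %d got %lf, %lf\n", rank, mine[0], mine[1]);

    free(all);                            /* free(NULL) is a no-op on non-root ranks */
    MPI_Finalize();
    return 0;
}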

3. MPI_Gather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)

Gathers the send buffer of every process in the communicator into the root's receive buffer, in rank order. recvcount is the count received from each process (it must match what each process sends), so the receive buffer at the root must be large enough to hold the contributions of all processes.
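A matching toy sketch (again illustrative, not from the repository):

#include <mpi.h>
#include <stdio.h>
#include <stdlib.h>

/* MPI_Gather demo: each rank contributes one double; the root
 * receives size doubles, ordered by rank.  recvcount counts the
 * elements received from EACH process, so the root buffer needs
 * size * recvcount elements in total. */
int main(int argc, char *argv[])
{
    MPI_Init(&argc, &argv);
    int rank, size;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    double mine = 10.0 * rank;
    double *all = NULL;
    if (rank == 0)
        all = malloc(size * sizeof(double));

    MPI_Gather(&mine, 1, MPI_DOUBLE, all, 1, MPI_DOUBLE, 0, MPI_COMM_WORLD);

    if (rank == 0) {
        for (int i = 0; i < size; i++)
            printf("%lf ", all[i]);       /* 0.0 10.0 20.0 ... */
        putchar('\n');
        free(all);
    }
    MPI_Finalize();
    return 0;
}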

Tip: watch out for memory leaks; every malloc above has a matching free.