您的位置:首页 > 其它

一个高效的稀疏矩阵转换算法,COO格式转换为DIA格式

2015-01-13 11:13 134 查看
COO格式的矩阵存储格式为:

VAL:存储矩阵的值,VAL[x]表示索引为x的值

INDX:存储行的索引值,例如INDX[x]表示索引为x的值其在矩阵中的行号

JNDX:存储列的索引值,例如JNDX[x]表示索引为x的值其在矩阵中的列号
DIA格式的矩阵存储格式为:

VAL:存储矩阵的值,VAL[x]表示索引为x的值,其中,不足lda长度的对角线需要补零,上三角补后面,下三角补在最前面

INDX:存储对角线的偏移值,例如INDX[x]表示索引为x的对角线其在矩阵中的相对中心对角线的偏移值,大于零位于上三角部分,等于零即中心对角线(最长的那根),小于零位于下三角部分

VAL[行号,列号]<=>VAL[对角线偏移,线内索引]:

上三角:行号=线内索引号,列号=对角线偏移+线内索引

下三角:行号=-对角线偏移+线内索引,列号=线内索引
函数参数含义

m:矩阵总行数

n:矩阵总列数

VAL:COO格式的矩阵非零元元素的值

n_VAL:矩阵的非零元个数

INDX:矩阵行索引

JDNX:矩阵列索引

LDA:对角线长度

NDIAG:对角线数量

代码如下:

void dpre_usconv_coo2dia  (int m,int n,double* VAL,int n_VAL,int* INDX,int* JNDX,int* LDA,int* NDIAG)
{
*NDIAG=0;
*LDA=m<n?m:n;
int num_rows,num_cols,num_nonzeros;
int i,ii,jj,offset,map_index,complete_ndiags,VAL_DIA_size;
const int unmarked = -1;
num_rows = m;
num_cols = n;
num_nonzeros = n_VAL;
complete_ndiags = 0;
int* diag_map = (int*)malloc(sizeof(int)*(num_rows+num_cols));     //mark the diags
int* diag_offset = (int*)malloc(sizeof(int)*(num_cols+num_rows));
memset(diag_map,unmarked,sizeof(int)*(num_rows+num_cols));
memset(diag_offset,unmarked,sizeof(int)*(num_rows+num_cols));

for(i=0;i<num_nonzeros;i++){
ii = INDX[i];
jj = JNDX[i];
map_index = num_rows-ii+jj;            //used to find the same diag
if(diag_map[map_index] == unmarked){
diag_map[map_index] = complete_ndiags;
diag_offset[map_index] = jj - ii;             //get index of diags
complete_ndiags++;                 //number of diags
}
}
*NDIAG = complete_ndiags;
VAL_DIA_size=(*NDIAG)*(*LDA);
int* IDIAG=(int*)malloc(sizeof(int)*(complete_ndiags));
double* VAL_DIA=(double*)malloc(sizeof(double)*VAL_DIA_size);
memset(IDIAG,0,sizeof(int)*(complete_ndiags));
memset(VAL_DIA,0,sizeof(double)*VAL_DIA_size);
for(i=0;i<num_rows + num_cols;i++){
if(diag_map[i] != unmarked){
IDIAG[diag_map[i]] = diag_offset[i];     //offset of diags
}
}

// for(i=0;i<complete_ndiags;i++){
//   printf("%d\t",IDIAG[i]);
// }

for(i=0;i<num_nonzeros;i++){    //get values of every diag
ii = INDX[i];
jj = JNDX[i];
map_index = num_rows-ii+jj;
int diag = diag_map[map_index];
if((*LDA)-(num_rows-ii)>0){
offset=(*LDA)-(num_rows-ii);
}else{
offset=0;
}
VAL_DIA[diag*(*LDA)+offset] = VAL[i];
}

// for(i=0;i<VAL_DIA_size;i++){
//   printf("%lf\t",VAL_DIA[i]);
// }
free(diag_map);
free(diag_offset);
//把结果的值拷贝到VAL,INDX中
//其中VAL_DIA中存放斜线中的值
//INDX中存放斜线号与斜线的偏移值
VAL=VAL_DIA;
INDX=IDIAG;
free(IDIAG);
free(VAL_DIA);
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: 
相关文章推荐