您的位置:首页 > 其它

【4920 Matrix multiplication】矩阵乘法优化+输入挂

2014-08-05 22:41 387 查看


Matrix multiplication

Time Limit: 4000/2000 MS (Java/Others)    Memory Limit: 131072/131072 K (Java/Others)

Total Submission(s): 1121    Accepted Submission(s): 474


Problem Description

Given two matrices A and B of size n×n, find the product of them.

bobo hates big integers. So you are only asked to find the result modulo 3.

 

Input

The input consists of several test cases. For each test case:

The first line contains $n$ ($1 \le n \le 800$). Each of the following $n$ lines contains $n$ integers — the description of the matrix $A$: the $j$-th integer in the $i$-th line equals $A_{ij}$. The next $n$ lines describe the matrix $B$ in the same format ($0 \le A_{ij}, B_{ij} \le 10^9$).

 

Output

For each test case:

Print $n$ lines, each containing $n$ integers — the matrix $A \times B$ modulo 3, in the same format.

 

Sample Input

1
0
1
2
0 1
2 3
4 5
6 7

 

Sample Output

0
0 1
2 1

 

Author

Xiaoxu Guo (ftiasch)

 

Source

2014 Multi-University Training Contest 5

坑爹啊~~~还尼玛有卡常数的,Strassen写到哭有木有啊!!

#define DeBUG
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cstdlib>
#include <algorithm>
#include <vector>
#include <stack>
#include <queue>
#include <string>
#include <set>
#include <sstream>
#include <map>
#include <list>
#include <bitset>
using namespace std ;
// Generic contest-template helpers. In this program only EPS (used by sgn)
// is referenced; zero, INF, TRUE, FALSE, LL and PI are unused here.
#define zero {0}
// A large sentinel value (unused in this program).
#define INF 0x3f3f3f3f
// Tolerance used by sgn() when comparing a double with zero.
#define EPS 1e-6
#define TRUE true
#define FALSE false
typedef long long LL;
const double PI = acos(-1.0);
// MSVC-specific linker pragma for a larger stack; intentionally left disabled.
//#pragma comment(linker, "/STACK:102400000,102400000")
// Sign of x with an EPS dead zone: 0 if |x| < EPS, otherwise -1 or +1.
inline int sgn(double x)
{
    if (fabs(x) < EPS)
        return 0;
    return x < 0 ? -1 : 1;
}
#define N 810
int n; // dimension of the current test case
// Read the next non-negative decimal integer from stdin and return it mod 3.
// Because 10 == 1 (mod 3), a number is congruent to the sum of its digits, so
// the (up to 1e9) value itself never has to be built.
// Any run of non-digit characters before the number (spaces, '\n', "\r\n", ...)
// is skipped; the single character after the number is consumed.
// Returns 0 when end of input is reached before any digit.
inline char read()
{
    int c = getchar(); // int, so EOF stays distinct from every real byte
    while (c != EOF && (c < '0' || c > '9'))
        c = getchar();
    int sum = 0; // int: cannot overflow for any realistic digit count
    while (c >= '0' && c <= '9')
    {
        sum += c - '0';
        c = getchar();
    }
    return (char)(sum % 3);
}
char a

, b

;
int val[3][3];
int t

;
// Fill the 3x3 lookup table val with products modulo 3, so the inner loop of
// main can index val[x][y] instead of multiplying and reducing.
void init()
{
    for (int cell = 0; cell < 9; ++cell)
    {
        const int row = cell / 3;
        const int col = cell % 3;
        val[row][col] = row * col % 3;
    }
}
// Solve each test case: read n, then A and B (reduced mod 3), print A*B mod 3.
int main()
{
#ifdef DeBUGs
    freopen("C:\\Users\\Sky\\Desktop\\1.in", "r", stdin);
#endif
    int i, j, k;
    int t;
    init();
    // scanf returns EOF at end of input and 0 on a malformed token; testing
    // for exactly 1 stops in both cases instead of spinning forever on 0.
    while (scanf("%d", &n) == 1)
    {
        getchar(); // consume the separator that follows n
        for (i = 0; i < n; i++)
            for (j = 0; j < n; j++)
                a[i][j] = read();
        // Store B transposed (b[j][i] holds B(i,j)) so the dot product below
        // walks both a[i][.] and b[j][.] contiguously in memory.
        for (i = 0; i < n; i++)
            for (j = 0; j < n; j++)
                b[j][i] = read();
        for (i = 0; i < n; i++)
        {
            for (j = 0; j < n; j++)
            {
                t = 0; // at most 4 * n, no overflow
                for (k = 0; k < n; k++)
                    t += val[a[i][k]][b[j][k]];
                if (j)
                    putchar(' ');
                putchar(t % 3 + '0');
            }
            putchar('\n');
        }
    }

    return 0;
}


另一份代码

#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;
// Fast reader: skip non-digit characters, then parse a non-negative decimal
// integer from stdin into ret. The character after the number is consumed.
// If end of input is reached before any digit, ret is set to 0 (previously
// the skip loop never saw EOF and spun forever).
inline void rd(int &ret)
{
    int c; // int, not char, so EOF is distinguishable from a real byte
    do
    {
        c = getchar();
    }
    while (c != EOF && (c < '0' || c > '9'));
    ret = 0;
    while (c >= '0' && c <= '9')
    {
        ret = ret * 10 + (c - '0');
        c = getchar();
    }
}
// Fast output: write the decimal digits of a (expected non-negative) to
// stdout, most significant digit first.
inline void ot(int a)
{
    char digits[12];
    int len = 0;
    do
    {
        digits[len++] = char(a % 10 + '0');
        a /= 10;
    }
    while (a > 0);
    while (len > 0)
        putchar(digits[--len]);
}
// Capacity per dimension: the problem allows n <= 800, plus a little slack.
const int MAX_N = 807;

int n;                 // dimension of the current test case
int a[MAX_N][MAX_N];   // matrix A, entries reduced mod 3
int b[MAX_N][MAX_N];   // matrix B, entries reduced mod 3
int c[MAX_N][MAX_N];   // A * B, accumulated before the final mod 3

// Process test cases until n can no longer be read: load A and B mod 3,
// multiply them, and print the product mod 3.
int main()
{
    for (;;)
    {
        if (scanf("%d", &n) != 1)
            break;

        int value;
        for (int row = 0; row < n; ++row)
            for (int col = 0; col < n; ++col)
            {
                rd(value);
                a[row][col] = value % 3;
            }
        for (int row = 0; row < n; ++row)
            for (int col = 0; col < n; ++col)
            {
                rd(value);
                b[row][col] = value % 3;
            }

        memset(c, 0, sizeof(c));
        // Loop-order optimization (i-k-j): the innermost loop sweeps one row
        // of b and one row of c sequentially; zero coefficients are skipped.
        for (int row = 0; row < n; ++row)
        {
            int *out = c[row];
            for (int mid = 0; mid < n; ++mid)
            {
                const int coef = a[row][mid];
                if (!coef)
                    continue;
                const int *src = b[mid];
                for (int col = 0; col < n; ++col)
                    out[col] += coef * src[col];
            }
        }

        for (int row = 0; row < n; ++row)
        {
            for (int col = 0; col < n; ++col)
            {
                if (col > 0)
                    putchar(' ');
                ot(c[row][col] % 3);
            }
            putchar('\n');
        }
    }
    return 0;
}


这份代码写完了虽然看着让人想哭,但还是贴这里吧,希望某某年可以用得着

#define DeBUG
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cstdlib>
#include <algorithm>
#include <vector>
#include <stack>
#include <queue>
#include <string>
#include <set>
#include <sstream>
#include <map>
#include <list>
#include <bitset>
using namespace std ;
// Generic contest-template helpers. In this program only EPS (used by sgn)
// is referenced; zero, INF, TRUE, FALSE, LL and PI are unused here.
#define zero {0}
// A large sentinel value (unused in this program).
#define INF 0x3f3f3f3f
// Tolerance used by sgn() when comparing a double with zero.
#define EPS 1e-6
#define TRUE true
#define FALSE false
typedef long long LL;
const double PI = acos(-1.0);
// MSVC-specific linker pragma for a larger stack; intentionally left disabled.
//#pragma comment(linker, "/STACK:102400000,102400000")
// Three-way sign with tolerance: values within EPS of zero count as zero.
inline int sgn(double x)
{
    if (!(fabs(x) < EPS))
        return x < 0 ? -1 : 1;
    return 0;
}
#define N 100005
// Global working matrices: A and B hold the zero-padded inputs, C the product.
int **A, **B, **C;
// Modulus for all arithmetic in this program.
int mod = 3;
// Allocate uninitialised n x n storage (n row arrays each) for A, B and C.
// The "[n]" bounds of these new-expressions had been stripped from the
// listing and are restored here.
void init(int n)
{
    A = new int *[n];
    B = new int *[n];
    C = new int *[n];
    for (int i = 0; i < n; i++)
    {
        A[i] = new int[n];
        B[i] = new int[n];
        C[i] = new int[n];
    }
}
// Zero the top-left n x n corner of the global matrices A, B and C.
inline void clear(int n)
{
    for (int row = 0; row < n; ++row)
    {
        int *ra = A[row];
        int *rb = B[row];
        int *rc = C[row];
        for (int col = 0; col < n; ++col)
        {
            rc[col] = 0;
            rb[col] = 0;
            ra[col] = 0;
        }
    }
}
// Copy the four n x n quadrants of the 2n x 2n matrix A into separate matrices:
//   [ A11 A12 ]
//   [ A21 A22 ]
inline void Divide(int n, int **A, int **A11, int **A12, int **A21, int **A22)
{
    for (int r = 0; r < n; ++r)
    {
        const int *top = A[r];
        const int *bottom = A[r + n];
        for (int c = 0; c < n; ++c)
        {
            A11[r][c] = top[c];
            A12[r][c] = top[c + n];
            A21[r][c] = bottom[c];
            A22[r][c] = bottom[c + n];
        }
    }
}
// Inverse of Divide: assemble the 2n x 2n matrix A from four n x n quadrants
//   [ A11 A12 ]
//   [ A21 A22 ]
inline void Merge(int n, int **A, int **A11, int **A12, int **A21, int **A22)
{
    for (int r = 0; r < n; ++r)
    {
        int *top = A[r];
        int *bottom = A[r + n];
        for (int c = 0; c < n; ++c)
        {
            top[c] = A11[r][c];
            top[c + n] = A12[r][c];
            bottom[c] = A21[r][c];
            bottom[c + n] = A22[r][c];
        }
    }
}
// Element-wise C = (A - B) % mod over n x n matrices. C++'s % keeps the sign
// of the dividend, so results lie in (-mod, mod) and may be negative.
inline void Sub(int n, int **A, int **B, int **C)
{
    for (int r = 0; r < n; ++r)
    {
        const int *lhs = A[r];
        const int *rhs = B[r];
        int *out = C[r];
        for (int c = 0; c < n; ++c)
            out[c] = (lhs[c] - rhs[c]) % mod;
    }
}
// Element-wise C = (A + B) % mod over n x n matrices.
inline void Add(int n, int **A, int **B, int **C)
{
    for (int r = 0; r < n; ++r)
    {
        const int *lhs = A[r];
        const int *rhs = B[r];
        int *out = C[r];
        for (int c = 0; c < n; ++c)
            out[c] = (lhs[c] + rhs[c]) % mod;
    }
}
// Release an n-row matrix allocated as new int*[n] plus new int[...] rows:
// every row and then the row-pointer array itself (previously leaked).
// A must not be used after this call.
inline void freeit(int **A, int n)
{
    for (int i = 0; i < n; i++)
    {
        delete []A[i];
    }
    delete []A;
}
// Read the next non-negative decimal integer from stdin and return it mod 3.
// Because 10 == 1 (mod 3), a number is congruent to the sum of its digits, so
// the (up to 1e9) value itself never has to be built.
// Any run of non-digit characters before the number (spaces, '\n', "\r\n", ...)
// is skipped; the single character after the number is consumed.
// Returns 0 when end of input is reached before any digit.
inline int read()
{
    int c = getchar(); // int, so EOF stays distinct from every real byte
    while (c != EOF && (c < '0' || c > '9'))
        c = getchar();
    int sum = 0;
    while (c >= '0' && c <= '9')
    {
        sum += c - '0';
        c = getchar();
    }
    return sum % 3;
}
// Allocate an uninitialised n x n matrix as n separately allocated rows.
static int **newSquare(int n)
{
    int **m = new int *[n];
    for (int i = 0; i < n; i++)
        m[i] = new int[n];
    return m;
}

// Release a matrix obtained from newSquare(n): every row, then the row array.
static void deleteSquare(int **m, int n)
{
    for (int i = 0; i < n; i++)
        delete[] m[i];
    delete[] m;
}

// M = A * B for n x n matrices, modulo `mod`: Strassen's seven-product
// recursion when n > 256, the schoolbook product otherwise.
//
// Assumes (unchecked) that n stays even at every recursive level above 256,
// e.g. n is a power of two (main pads to one), and that M does not alias A or B.
// M's entries are only congruent to the true product modulo `mod`: the base
// case does not reduce its sums, and Sub() can leave negative remainders, so
// callers should normalise with (x + mod) % mod.
// All 21 temporaries (rows and row-pointer arrays) are released before return.
inline void Mutiply(int n, int **A, int **B, int **M)
{
    if (n <= 256)
    {
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                M[i][j] = 0;
        // i-j-k order: the innermost loop walks rows of B and M contiguously.
        for (int i = 0; i < n; i++)
        {
            int *mi = M[i];
            for (int j = 0; j < n; j++)
            {
                const int aij = A[i][j];
                if (aij == 0)
                    continue; // contributes nothing to row i
                const int *bj = B[j];
                for (int k = 0; k < n; k++)
                    mi[k] += (aij * bj[k]) % mod;
            }
        }
        return;
    }

    const int h = n / 2; // quadrant size
    int **A11 = newSquare(h), **A12 = newSquare(h);
    int **A21 = newSquare(h), **A22 = newSquare(h);
    int **B11 = newSquare(h), **B12 = newSquare(h);
    int **B21 = newSquare(h), **B22 = newSquare(h);
    int **M11 = newSquare(h), **M12 = newSquare(h);
    int **M21 = newSquare(h), **M22 = newSquare(h);
    int **M1 = newSquare(h), **M2 = newSquare(h), **M3 = newSquare(h);
    int **M4 = newSquare(h), **M5 = newSquare(h), **M6 = newSquare(h);
    int **M7 = newSquare(h);
    int **T1 = newSquare(h), **T2 = newSquare(h); // scratch operands

    Divide(h, A, A11, A12, A21, A22);
    Divide(h, B, B11, B12, B21, B22);

    // The seven Strassen products.
    Sub(h, B12, B22, T1);
    Mutiply(h, A11, T1, M1);   // M1 = A11 (B12 - B22)
    Add(h, A11, A12, T2);
    Mutiply(h, T2, B22, M2);   // M2 = (A11 + A12) B22
    Add(h, A21, A22, T1);
    Mutiply(h, T1, B11, M3);   // M3 = (A21 + A22) B11

    Sub(h, B21, B11, T1);
    Mutiply(h, A22, T1, M4);   // M4 = A22 (B21 - B11)

    Add(h, A11, A22, T1);
    Add(h, B11, B22, T2);
    Mutiply(h, T1, T2, M5);    // M5 = (A11 + A22)(B11 + B22)

    Sub(h, A12, A22, T1);
    Add(h, B21, B22, T2);
    Mutiply(h, T1, T2, M6);    // M6 = (A12 - A22)(B21 + B22)

    Sub(h, A11, A21, T1);
    Add(h, B11, B12, T2);
    Mutiply(h, T1, T2, M7);    // M7 = (A11 - A21)(B11 + B12)

    // Recombine into the quadrants of the product.
    Add(h, M5, M4, T1);
    Sub(h, T1, M2, T2);
    Add(h, T2, M6, M11);       // M11 = M5 + M4 - M2 + M6

    Add(h, M1, M2, M12);       // M12 = M1 + M2

    Add(h, M3, M4, M21);       // M21 = M3 + M4

    Add(h, M5, M1, T1);
    Sub(h, T1, M3, T2);
    Sub(h, T2, M7, M22);       // M22 = M5 + M1 - M3 - M7

    Merge(h, M, M11, M12, M21, M22);

    int **temps[] = {A11, A12, A21, A22, B11, B12, B21, B22,
                     M11, M12, M21, M22, M1, M2, M3, M4, M5, M6, M7,
                     T1, T2};
    const int count = (int)(sizeof(temps) / sizeof(temps[0]));
    for (int t = 0; t < count; t++)
        deleteSquare(temps[t], h);
}
// For each test case: read n, A and B (reduced mod 3), multiply them with
// Strassen on zero-padded power-of-two matrices, and print the n x n product.
int main()
{
#ifdef DeBUGs
    freopen("C:\\Users\\Sky\\Desktop\\1.in", "r", stdin);
#endif
    int n;
    init(1024); // smallest power of two >= the maximum n (800)
    // Stop on EOF *and* on a malformed token (scanf returning 0).
    while (scanf("%d", &n) == 1)
    {
        // Strassen halves the matrices, so pad up to a power of two.
        int k = 1;
        while (k < n)
        {
            k <<= 1;
        }
        clear(k); // padding rows/columns must be zero
        getchar();
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                A[i][j] = read();
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                B[i][j] = read();
        Mutiply(k, A, B, C);
        // Print only the real n x n product, not the k x k padded one.
        // Entries may be negative residues (see Sub), hence the "+ mod".
        for (int i = 0; i < n; i++)
        {
            printf("%d", (C[i][0] + mod) % mod);
            for (int j = 1; j < n; j++)
            {
                printf(" %d", (C[i][j] + mod) % mod);
            }
            printf("\n");
        }
    }
    return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: