您的位置:首页 > 其它

[UOJ34]FFT && NTT 模板

2015-07-03 15:50 302 查看
到底为什么这么慢啊？

FFT

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<cmath>
// Shorthand aliases: SF for scanf, PF for printf.
#define SF scanf
#define PF printf
using namespace std;
// LL is a shorthand for long long.
typedef long long LL;
// Reads one (optionally negative) decimal integer from stdin.
// Non-digit characters before the number are skipped; a '-' seen among them
// makes the result negative. At end of input it returns whatever has been
// accumulated (0 if no digits were read) instead of looping forever.
// `ch` must be an int: storing getchar()'s result in a char makes EOF
// indistinguishable from a data byte, and the skip loop then never ends.
inline int read() {
	int x = 0, f = 1;
	int ch = getchar();
	while(ch != EOF && (ch < '0' || ch > '9')) {
		if(ch == '-') f = -1;
		ch = getchar();
	}
	while(ch >= '0' && ch <= '9') {
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	return x * f;
}
const int MAXN = 200000;
const double PI = acos(-1.0);  // pi to double precision

// Minimal complex number providing just the arithmetic the FFT needs.
struct cpx {
	double re;  // real part
	double im;  // imaginary part

	// Deliberately leaves the members uninitialised; the arrays below have
	// static storage duration and are therefore zero-filled anyway.
	cpx () {}
	cpx (double real_part, double imag_part) {
		re = real_part;
		im = imag_part;
	}

	cpx operator + (const cpx &rhs) const {
		cpx sum(*this);
		sum.re += rhs.re;
		sum.im += rhs.im;
		return sum;
	}

	cpx operator - (const cpx &rhs) const {
		cpx diff(*this);
		diff.re -= rhs.re;
		diff.im -= rhs.im;
		return diff;
	}

	// (a + bi)(c + di) = (ac - bd) + (ad + bc)i
	cpx operator * (const cpx &rhs) const {
		const double real_part = re * rhs.re - im * rhs.im;
		const double imag_part = re * rhs.im + im * rhs.re;
		return cpx(real_part, imag_part);
	}
};

// Input coefficient arrays (A, B) and transform buffers (fa, fb).
cpx A[MAXN*4+10];
cpx B[MAXN*4+10];
cpx fa[MAXN*4+10];
cpx fb[MAXN*4+10];

int n, m;  // polynomial sizes, read in main
int N;     // transform length, set by fft_init
int rev[MAXN*4+10];  // bit-reversal permutation for length N
// Chooses the transform length N = smallest power of two >= n, and fills
// rev[0..N-1] with the bit-reversal permutation on lg = log2(N) bits.
// rev is rebuilt from scratch in O(N). The previous version shifted bits into
// the existing rev[i] and so only worked on a zero-filled array (a second call
// with a different size produced garbage). It also wrote the unused rev[N].
void fft_init(int n) {
	N = 1; int lg = 0;
	while(N < n) N <<= 1, lg++;

	rev[0] = 0;
	// rev[i] = reverse(i >> 1) shifted down one bit, plus i's low bit moved to
	// the top. The loop only runs when N >= 2, so lg - 1 >= 0 here.
	for(int i = 1; i < N; i++)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
}
// Returns the unit complex number cos(alp) + i*sin(alp), i.e. e^(i*alp).
cpx EXP(double alp) {
	cpx w(0, 0);
	w.re = cos(alp);
	w.im = sin(alp);
	return w;
}
// Iterative radix-2 FFT of length N (as set by fft_init), using rev for the
// initial bit-reversal reordering.
//   a   : input, N elements (not modified)
//   out : receives the N transformed values
//   sig : +1 for the forward transform, -1 for the inverse; the inverse
//         result is divided by N.
// The whole input is copied into a scratch buffer before out is written, so
// a and out may be the same array.
void fft(cpx *a, cpx *out, int sig) {
	static cpx tmp[MAXN*4+10];
	for(int i = 0; i < N; i++) tmp[rev[i]] = a[i];
	for(int step = 1; step < N; step <<= 1) {
		int dou = step << 1;
		for(int i = 0; i < step; i++) {
			// Twiddle factor evaluated directly rather than by repeated
			// multiplication, which would accumulate rounding error.
			cpx wi = EXP(sig * PI * i / step);
			for(int k = i; k < N; k += dou) {
				int kk = k + step;
				cpx u = tmp[k], v = tmp[kk] * wi;
				tmp[k] = u + v;
				tmp[kk] = u - v;
			}
		}
	}
	for(int i = 0; i < N; i++) out[i] = tmp[i];
	if(sig < 0) {
		// Scale both components. Dividing only .re left .im N times too large,
		// so the result was not the true inverse DFT.
		for(int i = 0; i < N; i++) {
			out[i].re /= N;
			out[i].im /= N;
		}
	}
}
// Reads degrees n, m and the coefficients of two polynomials, then prints the
// n+m-1 coefficients of their product, computed with floating-point FFT.
int main() {
	n = read(); m = read();
	n++; m++;  // degrees -> coefficient counts
	for(int i = 0; i < n; i++) A[i].re = read();
	for(int i = 0; i < m; i++) B[i].re = read();
	// The product has n+m-1 coefficients. Any transform length >= that avoids
	// cyclic wrap-around. The former `* 2` doubled every transform for nothing.
	fft_init(n + m - 1);
	fft(A, fa, 1);
	fft(B, fb, 1);
	for(int i = 0; i < N; i++) fa[i] = fa[i] * fb[i];
	fft(fa, fb, -1);
	// Round to nearest. A plain (int)(x + 0.5) truncates toward zero and
	// mis-rounds negative coefficients (read() accepts '-').
	for(int i = 0; i < n + m - 1; i++)
		PF("%d ", static_cast<int>(floor(fb[i].re + 0.5)));
	PF("\n");
	return 0;
}


NTT

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<cmath>
// Shorthand aliases: SF for scanf, PF for printf.
#define SF scanf
#define PF printf
using namespace std;
// LL is a shorthand for long long.
typedef long long LL;
// Reads one (optionally negative) decimal integer from stdin.
// Non-digit characters before the number are skipped; a '-' seen among them
// makes the result negative. At end of input it returns whatever has been
// accumulated (0 if no digits were read) instead of looping forever.
// `ch` must be an int: storing getchar()'s result in a char makes EOF
// indistinguishable from a data byte, and the skip loop then never ends.
inline int read() {
	int x = 0, f = 1;
	int ch = getchar();
	while(ch != EOF && (ch < '0' || ch > '9')) {
		if(ch == '-') f = -1;
		ch = getchar();
	}
	while(ch >= '0' && ch <= '9') {
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	return x * f;
}
const int MAXN = 400000;
const int MOD = 998244353;  // prime, equal to 119 * 2^23 + 1
const int G0 = 3;           // primitive root used to derive the roots of unity

// Input coefficient arrays (A, B) and their transforms (fa, fb).
int A[MAXN*2+10];
int B[MAXN*2+10];
int fa[MAXN*2+10];
int fb[MAXN*2+10];
// G: successive powers of the root of unity; rev: bit-reversal permutation.
// Both are filled by fft_init.
int G[MAXN*2+10];
int rev[MAXN*2+10];
int n, m, N, inv;

// Returns x^k modulo MOD by square-and-multiply over the bits of k,
// least significant bit first. Intended for k >= 0.
int pow_mod(int x, int k) {
	long long base = x;
	long long acc = 1;
	for(; k != 0; k >>= 1) {
		if(k & 1)
			acc = acc * base % MOD;
		base = base * base % MOD;
	}
	return static_cast<int>(acc);
}
// Chooses the transform length N = smallest power of two >= n and prepares:
//   - G[0..N]: G[i] = g^i, where g = G0^((MOD-1)/N) is an N-th root of unity
//     (requires N to divide MOD-1, i.e. N <= 2^23 for this modulus);
//   - inv = N^(-1) mod MOD;
//   - rev[0..N-1]: bit-reversal permutation on lg = log2(N) bits.
// rev is rebuilt from scratch in O(N). The previous version shifted bits into
// the existing rev[i] and so only worked on a zero-filled array.
void fft_init(int n) {
	N = 1; int lg = 0;
	while(N < n) N <<= 1, lg++;

	int g = pow_mod(G0, (MOD-1) / N);
	inv = pow_mod(N, MOD-2);
	// G[N] (== g^N == 1) is also needed: the inverse transform reads G[N - 0].
	G[0] = 1;
	for(int i = 1; i <= N; i++) G[i] = 1LL * G[i-1] * g % MOD;

	rev[0] = 0;
	// The loop only runs when N >= 2, so lg - 1 >= 0 here.
	for(int i = 1; i < N; i++)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
}
void fft(int *out, int *a, int sig) {
	static int tmp[MAXN*2+10];
	for(int i = 0; i < N; i++) tmp[rev[i]] = a[i];
	for(int step = 1; step < N; step <<= 1) {
		int dou = step << 1;
		for(int i = 0; i < step; i++) {
			int wi = sig > 0 ? G[ i * (N / dou) ] : G[ N - i * (N / dou) ];
			for(int k = i; k < N; k += dou) {
				int kk = k+step;
				int u = tmp[k], v = 1LL * tmp[kk] * wi % MOD;
				tmp[k] = (u+v) % MOD;
				tmp[kk] = ((u-v) % MOD+MOD) % MOD;
			}
		}
	}
	for(int i = 0; i < N; i++) out[i] = tmp[i];
}
// Reads degrees n, m and the coefficients of two polynomials, then prints the
// n+m-1 coefficients of their product modulo MOD, computed with NTT.
int main() {
	n = read(); m = read();
	n++; m++;  // degrees -> coefficient counts
	for(int i = 0; i < n; i++) A[i] = read();
	for(int i = 0; i < m; i++) B[i] = read();
	// The product has n+m-1 coefficients. Any transform length >= that avoids
	// cyclic wrap-around. The former `* 2` doubled every transform for nothing.
	fft_init(n + m - 1);
	fft(fa, A, 1);
	fft(fb, B, 1);
	// Pointwise product, with the inverse transform's 1/N factor (inv) folded in.
	for(int i = 0; i < N; i++) fa[i] = fa[i] * 1LL * fb[i] % MOD * inv % MOD;
	fft(fb, fa, -1);
	for(int i = 0; i < n + m - 1; i++) PF("%d ", fb[i]);
	PF("\n");
	return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: