您的位置:首页 > 其它

[uoj 34 多项式乘法] FFT&NTT 模板

2017-04-01 00:24 549 查看

[uoj 34 多项式乘法] FFT&NTT 模板

分类:
模板
FFT
NTT


1. 题目链接

[uoj 34 多项式乘法]

2. 题意描述

给你两个多项式,请输出乘起来后的多项式。

第一行两个整数 n 和 m,分别表示两个多项式的次数。

第二行 n+1个整数,分别表示第一个多项式的 0 到 n 次项前的系数。

第三行 m+1 个整数,分别表示第二个多项式的 0 到 m 次项前的系数。

3. 解题思路

模板测试题。给出FFT和NTT的板子。

可以直接去[uoj statistics] 查看更好的板子。

4. 实现代码

#include <bits/stdc++.h>
using namespace std;

// Short aliases for frequently used types.
using LL = long long;
using LB = long double;
using PII = std::pair<int, int>;
using PLL = std::pair<LL, LL>;
using VII = std::vector<int>;

// Large sentinel values whose every byte is 0x3f.
const int INF = 0x3f3f3f3f;
const LL INFL = 0x3f3f3f3f3f3f3f3fLL;
// Small floating-point tolerance (not referenced by the code shown).
const double eps = 1e-8;
// Pi, computed as acos(-1); used for the FFT roots of unity.
const double PI = std::acos(-1.0);

/// Fast reader: parses the next decimal integer (optionally '-'-signed) from stdin.
/// Any characters before the first '-' or digit are skipped.
/// Returns false when input ends before a number starts (ret is left unchanged),
/// true otherwise.
template <typename T>
inline bool scan_d(T &ret) {
    // getchar() returns int. Storing it in a char would make EOF either
    // undetectable (unsigned char) or indistinguishable from byte 0xFF (signed char).
    int c = getchar();
    // Check EOF while skipping separators. Otherwise trailing whitespace at the
    // end of input would spin forever.
    while (c != EOF && c != '-' && (c < '0' || c > '9')) c = getchar();
    if (c == EOF) return false;
    const bool negative = (c == '-');
    ret = negative ? 0 : (c - '0');
    while ((c = getchar()) >= '0' && c <= '9') ret = ret * 10 + (c - '0');
    if (negative) ret = -ret;
    return true;
}
/// Writes integer x to stdout in decimal (a leading '-' if negative), with no newline.
template<typename T>
void print(T x) {
    // Digits are collected least-significant first, then emitted in reverse.
    // 40 chars is enough for any built-in integer up to 128 bits.
    char digits[40];
    int len = 0;
    const bool negative = x < 0;
    // Digits are taken from x without negating it first. Negating the minimum
    // value of a signed type (e.g. INT_MIN) overflows, which is undefined behavior.
    do {
        T rem = x % 10;  // in (-10, 0] for negative x (division truncates toward zero)
        digits[len++] = static_cast<char>('0' + (rem < 0 ? -rem : rem));
        x /= 10;
    } while (x != 0);
    if (negative) putchar('-');
    while (len > 0) putchar(digits[--len]);
}
/// Same as print(x), followed by a newline.
template<typename T> void println(T x) { print(x); putchar('\n'); }

/// Capacity of the scratch arrays. It must be at least the padded power-of-two
/// transform length chosen in trans(); 262144 = 2^18, plus a small slack.
const int MAXN = 262144 + 5;

/// Minimal complex number: a hand-rolled stand-in for std::complex<double>.
struct CP {
    double x = 0;  // real part
    double y = 0;  // imaginary part
    CP() = default;  // zero-initialized, never indeterminate
    CP(double x, double y) : x(x), y(y) {}
    double real() const { return x; }
    double imag() const { return y; }
    CP operator*(const CP& r) const { return CP(x * r.x - y * r.y, x * r.y + y * r.x); }
    CP operator-(const CP& r) const { return CP(x - r.x, y - r.y); }
    CP operator+(const CP& r) const { return CP(x + r.x, y + r.y); }
};

CP a[MAXN], b[MAXN];     // transform buffers for the two operands
int r[MAXN], res[MAXN];  // r: bit-reversal permutation; res: product coefficients

/// Fills r[0..nm) with the bit-reversal permutation for nm == 2^k.
/// r[i] is i with its k low bits reversed; fft() uses it to reorder its input.
void fft_init(int nm, int k) {
    r[0] = 0;
    // With k == 0 (nm == 1) the permutation is just {0}. Returning early avoids
    // the undefined shift by k - 1 == -1.
    if (k <= 0) return;
    for (int i = 1; i < nm; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (k - 1));  // Rader step
}

/// In-place iterative radix-2 FFT over ax[0..nm), where nm is a power of two.
/// op == 1: forward transform.
/// op == -1: inverse transform; real parts are divided by nm, imaginary parts
/// are left unscaled.
/// Requires r[] to already hold the bit-reversal permutation (see fft_init).
void fft(CP ax[], int nm, int op) {
    // Permute into bit-reversed order, swapping each pair exactly once.
    for (int i = 0; i < nm; ++i) {
        const int j = r[i];
        if (i < j) swap(ax[i], ax[j]);
    }
    // Merge adjacent sub-transforms of length `half` into length `len`.
    for (int half = 1; 2 * half <= nm; half *= 2) {
        const int len = 2 * half;
        // len-th root of unity e^{op * 2*pi*i / len}.
        const CP root(cos(op * 2 * PI / len), sin(op * 2 * PI / len));
        for (int base = 0; base < nm; base += len) {
            CP tw(1, 0);  // twiddle factor root^k
            for (int k = 0; k < half; ++k) {
                CP& lo = ax[base + k];
                CP& hi = ax[base + k + half];
                const CP t = tw * hi;  // butterfly
                hi = lo - t;
                lo = lo + t;
                tw = tw * root;
            }
        }
    }
    if (op == -1) {
        for (int i = 0; i < nm; ++i) ax[i].x /= nm;
    }
}

/// Multiplies polynomial ax (n coefficients) by bx (m coefficients) using FFT.
/// The n + m - 1 product coefficients are stored in res[] and printed
/// space-separated on one line. Expects n, m >= 1.
void trans(int ax[], int bx[], int n, int m) {
    const int out_len = n + m - 1;
    // Smallest power of two nm >= out_len, so the cyclic convolution does not wrap.
    // Starting at 2 keeps k >= 1 for fft_init.
    int nm = 2, k = 1;
    while (nm < out_len) nm <<= 1, ++k;

    for (int i = 0; i < n; ++i) a[i] = CP(ax[i], 0);
    for (int i = 0; i < m; ++i) b[i] = CP(bx[i], 0);
    for (int i = n; i < nm; ++i) a[i] = CP(0, 0);
    for (int i = m; i < nm; ++i) b[i] = CP(0, 0);

    fft_init(nm, k);
    fft(a, nm, 1);
    fft(b, nm, 1);
    for (int i = 0; i < nm; ++i) a[i] = a[i] * b[i];  // pointwise product
    fft(a, nm, -1);

    for (int i = 0; i < out_len; ++i) {
        // lround rounds to nearest, so negative coefficients come out right.
        // The old (int)(x + 0.5) truncates toward zero: -2.9999 + 0.5 -> -2, not -3.
        res[i] = (int)std::lround(a[i].real());
        print(res[i]);
        putchar(" \n"[i == out_len - 1]);
    }
}

int main() {
#ifdef ___LOCAL_WONZY___
freopen("input.txt", "r", stdin);
#endif // ___LOCAL_WONZY___
static int ax[MAXN], bx[MAXN], n, m;

scan_d(n); scan_d(m); ++n, ++m;
for(int i = 0; i < n; ++i) scan_d(ax[i]);
for(int i = 0; i < m; ++i) scan_d(bx[i]);

trans(ax, bx, n, m);

return 0;
}


/** NTT **/
#include <bits/stdc++.h>
using namespace std;

// Short aliases for frequently used types.
using LL = long long;
using LB = long double;
using PII = std::pair<int, int>;
using PLL = std::pair<LL, LL>;
using VII = std::vector<int>;

// Large sentinel values whose every byte is 0x3f.
const int INF = 0x3f3f3f3f;
const LL INFL = 0x3f3f3f3f3f3f3f3fLL;
// Small floating-point tolerance (not referenced by the code shown).
const double eps = 1e-8;
// Pi, computed as acos(-1) (not referenced by the NTT code shown).
const double PI = std::acos(-1.0);

/// Fast reader: parses the next decimal integer (optionally '-'-signed) from stdin.
/// Any characters before the first '-' or digit are skipped.
/// Returns false when input ends before a number starts (ret is left unchanged),
/// true otherwise.
template <typename T>
inline bool scan_d(T &ret) {
    // getchar() returns int. Storing it in a char would make EOF either
    // undetectable (unsigned char) or indistinguishable from byte 0xFF (signed char).
    int c = getchar();
    // Check EOF while skipping separators. Otherwise trailing whitespace at the
    // end of input would spin forever.
    while (c != EOF && c != '-' && (c < '0' || c > '9')) c = getchar();
    if (c == EOF) return false;
    const bool negative = (c == '-');
    ret = negative ? 0 : (c - '0');
    while ((c = getchar()) >= '0' && c <= '9') ret = ret * 10 + (c - '0');
    if (negative) ret = -ret;
    return true;
}
/// Writes integer x to stdout in decimal (a leading '-' if negative), with no newline.
template<typename T>
void print(T x) {
    // Digits are collected least-significant first, then emitted in reverse.
    // 40 chars is enough for any built-in integer up to 128 bits.
    char digits[40];
    int len = 0;
    const bool negative = x < 0;
    // Digits are taken from x without negating it first. Negating the minimum
    // value of a signed type (e.g. INT_MIN) overflows, which is undefined behavior.
    do {
        T rem = x % 10;  // in (-10, 0] for negative x (division truncates toward zero)
        digits[len++] = static_cast<char>('0' + (rem < 0 ? -rem : rem));
        x /= 10;
    } while (x != 0);
    if (negative) putchar('-');
    while (len > 0) putchar(digits[--len]);
}
/// Same as print(x), followed by a newline.
template<typename T> void println(T x) { print(x); putchar('\n'); }

/// Capacity of the work arrays; must cover the padded power-of-two transform length.
const int MAXN = 262144 + 5;
// NTT-friendly prime 998244353 = 119 * 2^23 + 1, with generator G = 3.
const int G = 3;
const int MOD = 998244353;

int a[MAXN];    // first operand; after trans() its transform becomes the product
int b[MAXN];    // second operand
int r[MAXN];    // bit-reversal permutation
int res[MAXN];  // printed product coefficients

template<typename T>
T quick_pow(T a, T b) {
T ret = 1;
while(b) {
if(b & 1) ret = (LL)ret * a % MOD;
a = (LL)a * a % MOD;
b >>= 1;
}
return ret;
}

/// Fills r[0..nm) with the bit-reversal permutation for nm == 2^k.
/// r[i] is i with its k low bits reversed; ntt() uses it to reorder its input.
void ntt_init(int nm, int k) {
    r[0] = 0;
    // With k == 0 (nm == 1) the permutation is just {0}. Returning early avoids
    // the undefined shift by k - 1 == -1.
    if (k <= 0) return;
    for (int i = 1; i < nm; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (k - 1));  // Rader step
}

template<typename T>
void ntt(T ax[], int nm, int op) {
for(int i = 0; i < nm; ++i) if(i < r[i]) swap(ax[i], ax[r[i]]);
for(int h = 2, m = 1; h <= nm; h <<= 1, m <<= 1) {  /// 枚举长度
T wn = quick_pow(G, (MOD - 1) / h);
for(int i = 0; i < nm; i += h) {    /// 枚举所有长度为h的区间
T w = 1;                                  /// 旋转因子
for(int j = i; j < i + m; ++j, w = (LL)w * wn % MOD) { /// 枚举角度
T t = (LL)w * ax[j + m] % MOD;       /// 蝴蝶操作
ax[j + m] = ax[j] - t + MOD;
if(ax[j + m] >= MOD) ax[j + m] -= MOD;
ax[j] = ax[j] + t;
if(ax[j] >= MOD) ax[j] -= MOD;
}
}
}
if(op == -1) {
for(int i = 1; i < nm / 2; i++) swap(ax[i], ax[nm - i]); /// Caution Here!
T inv = quick_pow(nm, MOD - 2);
for(int i = 0; i < nm; ++i) ax[i] = (LL)ax[i] * inv % MOD;
}
}

template<typename T>
void trans(T ax[], T bx[], int n, int m) {

int nm = 1, k = 0;
while(nm < 2 * n || nm < 2 * m) nm <<= 1, ++k;

for(int i = 0; i < n; ++i) a[i] = ax[i];
for(int i = 0; i < m; ++i) b[i] = bx[i];
for(int i = n; i < nm; ++i) a[i] = 0;
for(int i = m; i < nm; ++i) b[i] = 0;

ntt_init(nm, k);
ntt(a, nm, 1); ntt(b, nm, 1);
for(int i = 0; i < nm; ++i) a[i] = (LL)a[i] * b[i] % MOD;
ntt(a, nm, -1);
nm = n + m - 1;
for(int i = 0; i < nm; ++i) res[i] = a[i], print(res[i]), putchar(" \n"[i == nm - 1]);
}

int main() {
#ifdef ___LOCAL_WONZY___
freopen("input.txt", "r", stdin);
#endif // ___LOCAL_WONZY___
static int n, m;

scan_d(n); scan_d(m); ++n, ++m;
for(int i = 0; i < n; ++i) scan_d(a[i]);
for(int i = 0; i < m; ++i) scan_d(b[i]);

trans(a, b, n, m);

return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: