您的位置:首页 > 其它

Code[VS] 3123 高精度练习之超大整数乘法

2016-12-21 17:35 435 查看
用 FFT 实现高精度乘法。

#include <bits/stdc++.h>

// The constant pi, evaluated once at startup as arccos(-1).
const double pi = std::acos(-1.0);

/*
 * Minimal double-precision complex number used by the FFT routines.
 * `a` holds the real part and `b` the imaginary part (operator* below
 * implements (a1 + b1 i)(a2 + b2 i)).
 */
struct complex
{
    double a, b;

    // Constructs re + im*i; omitted components default to zero.
    complex(double re = 0, double im = 0) : a(re), b(im) {}

    // Component-wise sum.
    friend complex operator+(const complex &lhs, const complex &rhs)
    {
        return complex(lhs.a + rhs.a, lhs.b + rhs.b);
    }

    // Component-wise difference.
    friend complex operator-(const complex &lhs, const complex &rhs)
    {
        return complex(lhs.a - rhs.a, lhs.b - rhs.b);
    }

    // Complex product: (a1*a2 - b1*b2) + (a1*b2 + b1*a2) i.
    friend complex operator*(const complex &lhs, const complex &rhs)
    {
        const double re = lhs.a * rhs.a - lhs.b * rhs.b;
        const double im = lhs.a * rhs.b + lhs.b * rhs.a;
        return complex(re, im);
    }

    // In-place sum; returns the updated left operand.
    friend complex &operator+=(complex &lhs, const complex &rhs)
    {
        lhs = lhs + rhs;
        return lhs;
    }

    // In-place difference; returns the updated left operand.
    friend complex &operator-=(complex &lhs, const complex &rhs)
    {
        lhs = lhs - rhs;
        return lhs;
    }

    // In-place product; safe when rhs aliases lhs because the product is
    // fully computed before the assignment.
    friend complex &operator*=(complex &lhs, const complex &rhs)
    {
        lhs = lhs * rhs;
        return lhs;
    }
};

/*
 * Returns the unit complex number cos(theta) + i*sin(theta), i.e. e^{i*theta}.
 * Used to produce the FFT twiddle factors.
 */
inline complex alpha(double theta)
{
    const double re = cos(theta);
    const double im = sin(theta);
    return complex(re, im);
}

// A sequence of complex samples/coefficients, as consumed and produced by FFT/IFFT.
using vec = std::vector<complex>;

/*
 * Recursive radix-2 Cooley-Tukey transform of `a` in the "forward" direction
 * used by this program: twiddle factors are alpha(+2*pi*k/n) = e^{+2*pi*i*k/n}.
 * IFFT uses the opposite sign, so IFFT(FFT(x)) == n * x.
 *
 * Precondition: a.size() is 0, 1, or a power of two. Sizes 0 and 1 are
 * returned unchanged. The result is not normalized.
 */
vec FFT(const vec &a)
{
    const int n = static_cast<int>(a.size());

    // Base case. It also covers n == 0, which previously recursed forever.
    if (n <= 1) return a;

    // An odd length cannot be split evenly, and the butterfly below would
    // silently produce wrong values.
    assert((n & (n - 1)) == 0);

    // Split into even- and odd-indexed samples and transform each half.
    vec ar[2];
    ar[0].reserve(n / 2);
    ar[1].reserve(n / 2);
    for (int i = 0; i < n; ++i)
        ar[i & 1].push_back(a[i]);

    vec yr[2];
    for (int i = 0; i < 2; ++i)
        yr[i] = FFT(ar[i]);

    // Butterfly: combine the two half transforms. Each twiddle factor is
    // computed directly instead of by repeated multiplication, so rounding
    // error does not accumulate across the n/2 iterations.
    vec y(n);
    for (int i = 0; i < n / 2; ++i)
    {
        const complex t = alpha(pi * 2 * i / n) * yr[1][i];
        y[i] = yr[0][i] + t;
        y[i + n / 2] = yr[0][i] - t;
    }

    return y;
}

/*
 * Recursive radix-2 inverse transform of `a`: identical to FFT except that the
 * twiddle factors are alpha(-2*pi*k/n) = e^{-2*pi*i*k/n}.
 *
 * The result is NOT divided by n: IFFT(FFT(x)) == n * x, so the caller must
 * divide each element by n to recover the original sequence.
 *
 * Precondition: a.size() is 0, 1, or a power of two. Sizes 0 and 1 are
 * returned unchanged.
 */
vec IFFT(const vec &a)
{
    const int n = static_cast<int>(a.size());

    // Base case. It also covers n == 0, which previously recursed forever.
    if (n <= 1) return a;

    // An odd length cannot be split evenly, and the butterfly below would
    // silently produce wrong values.
    assert((n & (n - 1)) == 0);

    // Split into even- and odd-indexed samples and transform each half.
    vec ar[2];
    ar[0].reserve(n / 2);
    ar[1].reserve(n / 2);
    for (int i = 0; i < n; ++i)
        ar[i & 1].push_back(a[i]);

    vec yr[2];
    for (int i = 0; i < 2; ++i)
        yr[i] = IFFT(ar[i]);

    // Butterfly with conjugate twiddles, each computed directly so rounding
    // error does not accumulate across the n/2 iterations.
    vec y(n);
    for (int i = 0; i < n / 2; ++i)
    {
        const complex t = alpha(-pi * 2 * i / n) * yr[1][i];
        y[i] = yr[0][i] + t;
        y[i + n / 2] = yr[0][i] - t;
    }

    return y;
}

// Operand text read by main(): room for 100004 characters plus the NUL.
char s1[100005];
int len1;   // length of s1, set by main() via strlen
char s2[100005];
int len2;   // length of s2, set by main() via strlen

// Working storage for main(): digit coefficient vectors, their transforms,
// the pointwise product, and the inverse transform of that product.
vec v1, v2;
vec p1, p2;
vec mul, ans;

/*
 * Reads two whitespace-separated decimal integers from stdin and prints their
 * product in decimal, followed by a newline.
 *
 * Each token is assumed to consist only of the digits '0'-'9' (no sign); other
 * characters are not validated. Tokens longer than 100004 characters are
 * truncated rather than overflowing the input buffers.
 *
 * Method:
 *   1. Turn each operand into a little-endian vector of digit coefficients.
 *   2. Zero-pad both to a power-of-two length that can hold every digit of
 *      the product.
 *   3. Multiply pointwise in the transform domain and transform back.
 *   4. Round the coefficients to integers and carry-normalize to base 10.
 *
 * Returns 0 on success, or 1 if either operand is missing from the input.
 */
signed main(void)
{
    // Bounded reads: s1/s2 hold at most 100004 characters plus the NUL.
    if (scanf("%100004s", s1) != 1 || scanf("%100004s", s2) != 1)
        return 1;
    len1 = static_cast<int>(strlen(s1));
    len2 = static_cast<int>(strlen(s2));

    // The product has at most len1 + len2 digits. FFT/IFFT halve the length
    // recursively, so round the transform size up to a power of two.
    int len = 1;
    while (len < len1 + len2)
        len <<= 1;

    // Store digits least-significant first, so index k is the coefficient of 10^k.
    v1.reserve(len);
    v2.reserve(len);
    for (int i = len1 - 1; i >= 0; --i)
        v1.push_back(complex(s1[i] - '0', 0));
    for (int i = len2 - 1; i >= 0; --i)
        v2.push_back(complex(s2[i] - '0', 0));
    v1.resize(len);   // pads with complex() == 0
    v2.resize(len);

    p1 = FFT(v1);
    p2 = FFT(v2);

    mul.reserve(len);
    for (int i = 0; i < len; ++i)
        mul.push_back(p1[i] * p2[i]);

    // IFFT leaves every element scaled by len; that is divided out below.
    ans = IFFT(mul);

    // Round each coefficient and propagate carries. Any carry left after the
    // last coefficient is appended, never written past the end of the buffer.
    std::vector<int> ret;
    ret.reserve(len + 1);
    long long carry = 0;
    for (int i = 0; i < len; ++i)
    {
        carry += std::llround(ans[i].a / len);
        ret.push_back(static_cast<int>(carry % 10));
        carry /= 10;
    }
    while (carry > 0)
    {
        ret.push_back(static_cast<int>(carry % 10));
        carry /= 10;
    }

    // Strip leading zeros, but keep a single 0 for a zero product.
    while (ret.size() > 1 && ret.back() == 0)
        ret.pop_back();

    for (int i = static_cast<int>(ret.size()) - 1; i >= 0; --i)
        putchar('0' + ret[i]);
    putchar('\n');
    return 0;
}


@Author: YouSiki
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: