您的位置:首页 > 其它

fft的迭代实现与ntt模板

2018-02-16 20:57 323 查看

前言:

重看了下fft递归实现,好像不难理解,以前真的太naive。然后被迭代版的各种吊打,赶紧补下,顺便学下ntt。

哈尔滨真的冷,让我这个GD蒟蒻怎么码代码

FFT:

先贴链接:快速傅里叶变换FFT的迭代实现

这篇博客讲的很清楚了,本质上是一样的,就是将底层排好序后,自底向上一层层求。

至于为什么这么排:显然

代码留坑。

#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<complex>
#include<cmath>
#include<iostream>
using namespace std;
const double pi=acos(-1);
int n,m,bin[400010];
complex <double>a[400010],b[400010],c[400010];
void fft(complex <double> *a,int n,int op)
{
for(int i=0;i<n;i++) if(i<bin[i]) swap(a[i],a[bin[i]]);
for(int i=1;i<n;i<<=1)
{
complex <double> wn(cos(pi/i),op*sin(pi/i)),t;
for(int j=0;j<n;j+=i<<1)
{
complex <double>w(1,0);
for(int k=0;k<i;k++)
{
t=w*a[i+j+k];w*=wn;
a[i+j+k]=a[j+k]-t;a[j+k]=a[j+k]+t;
}
}
}
}
int main()
{
scanf("%d %d",&n,&m);
for(int i=0;i<=n;i++) scanf("%lf",&a[i]);
for(int i=0;i<=m;i++) scanf("%lf",&b[i]);
m+=n;n=1;while(n<=m) n<<=1;
for(int i=0;i<n;i++) bin[i]=(bin[i>>1]>>1)|((i&1)*(n>>1));
fft(a,n,1);fft(b,n,1);
for(int i=0;i<n;i++) c[i]=a[i]*b[i];
fft(c,n,-1);
for(int i=0;i<=m;i++) printf("%d ",(int)(c[i].real()/n+0.5));
}


NTT:

用原根代替单位复根,可以支持modmod操作。

用NTT时要求模数比较特殊也许是我太弱

设模数为p=x∗2N+1(N>=logn)p=x∗2N+1(N>=logn)且是个质数,gg为pp的原根。

因为gp−1=1(mod p)gp−1=1(mod p)设gn=gp−1ngn=gp−1n代替wnwn

显然也是满足wnwn的几条性质的。(相消,折半等)

于是愉快的上代码。

code:

#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#define LL long long
using namespace std;
LL a[400010],b[400010],c[400010];
int p=1004535809,g=3,n,m,bin[400010];
LL pow(LL a,int b,int mod)
{
LL ans=1;
while(b)
{
if(b&1) ans=ans*a%mod;
a=a*a%mod;b>>=1;
}
return ans;
}
void ntt(LL *a,int n,int op)
{
for(int i=0;i<n;i++) if(i<bin[i]) swap(a[i],a[bin[i]]);
for(int i=1;i<n;i<<=1)
{
LL wn=pow((LL)g,op==1?(p-1)/(2*i):p-1-(p-1)/(2*i),p),t,w;
for(int j=0;j<n;j+=i<<1)
{
w=1;
for(int k=0;k<i;k++)
{
t=w*a[i+j+k]%p;w=w*wn%p;
a[i+j+k]=(a[j+k]-t+p)%p;a[j+k]=(a[j+k]+t)%p;
}
}
}
if(op==-1)
{
LL inv=pow(n,p-2,p);
for(int i=0;i<n;i++) a[i]=a[i]*inv%p;
}
}
int main()
{
scanf("%d %d",&n,&m);
for(int i=0;i<=n;i++) scanf("%lld",&a[i]);
for(int i=0;i<=m;i++) scanf("%lld",&b[i]);
m+=n;n=1;while(n<=m) n<<=1;
for(int i=0;i<n;i++) bin[i]=(bin[i>>1]>>1)|((i&1)*(n>>1));
ntt(a,n,1);ntt(b,n,1);
for(int i=0;i<n;i++) c[i]=a[i]*b[i];ntt(c,n,-1);
for(int i=0;i<=m;i++) printf("%lld ",c[i]);
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: