您的位置:首页 > 其它

UOJ272 [清华集训2016] 石家庄的工人阶级队伍比较坚强 【分治乘法】

2018-08-11 08:07 246 查看

题目分析:

首先不难注意到式子就是异或卷积,所以考虑用分治乘法推出优化方法。
我们把一个整体$f$拆成$f-,f\pm,f+$,然后另一个拆成$g-,g\pm,g+$.这样做的好处是能更清楚的分析问题。下面我们下宽油(大雾)。
发现三个部分要求的式子是在两者相乘中选不同的三个,所以我们发现三个部分中每取一个有相同。这样我们聚焦到$--,-\pm,-+$三个东西。观察二进制FWT,可以假想它们要使用到三次单位根。这样只需要把三个根错开排列就行了。
做分治乘法的时候注意把虚部的$I$记做$\sqrt{3}i$.

代码:

1 #include<bits/stdc++.h>
2 using namespace std;
3
4 const int maxn = 1020000;
5
6 struct cn{int rl,vir;}e[3]; // vir's real meaning is vir*sqrt(3)
7
8 int iv2,iv3;
9 int m,n,t,p,phi;
10 int b3[20],b[20][20];
11 cn val[maxn],f[maxn];
12
13 int W[maxn],L[maxn];
14
15 cn operator +(const cn& alpha,const cn& beta){
16     cn ans = (cn){alpha.rl+beta.rl,alpha.vir+beta.vir};
17     if(ans.rl >= p) ans.rl -= p;
18     if(ans.vir >= p) ans.vir -= p;
19     return ans;
20 }
21 cn operator *(const cn& alpha,const cn& beta){
22     cn ans = (cn){0,0};
23     ans.rl = (1ll*alpha.rl*beta.rl-3ll*alpha.vir*beta.vir)%p;
24     ans.rl += p; if(ans.rl >= p) ans.rl -= p;
25     ans.vir = (1ll*alpha.vir*beta.rl+1ll*alpha.rl*beta.vir)%p;
26     return ans;
27 }
28 cn operator *(const cn& alpha,const int& beta){
29     cn ans=alpha;ans.rl=(1ll*ans.rl*beta)%p;ans.vir=(1ll*ans.vir*beta)%p;
30     return ans;
31 }
32
33 cn fast_pow(cn now,int pw){
34     int bit = 1;cn ans = (cn){1,0},dt = now;
35     while(bit <= pw){
36     if(bit & pw) ans = ans*dt;
37     bit<<=1;dt = dt*dt;
38     }
39     return ans;
40 }
41 int fast_pow(int now,int pw){
42     int bit = 1,ans = 1,dt = now;
43     while(bit <= pw){
44     if(bit & pw) ans = (1ll*ans*dt)%p;
45     bit<<=1;dt = (1ll*dt*dt)%p;
46     }
47     return ans;
48 }
49
50 void read(){
51     scanf("%d%d%d",&m,&t,&p);
52     b3[0] = 1; for(int i=1;i<=m;i++) b3[i] = b3[i-1]*3;
53     n = b3[m];
54     for(int i=0;i<n;i++) scanf("%d",&f[i].rl);
55     for(int i=0;i<=m;i++){
56     for(int j=0;i+j<=m;j++){
57         scanf("%d",&b[i][j]);
58     }
59     }
60     val[0].rl = b[0][0];
61     for(int i=1;i<n;i++){
62     W[i] = W[i/3],L[i] = L[i/3];
63     if(i % 3 == 2) L[i]++;
64     if(i % 3 == 1) W[i]++;
65     val[i].rl = b[W[i]][L[i]];
66     }
67 }
68
69 void multi(int l,int r){
70     if(l == r-1){
71     f[l] = f[l]*fast_pow(val[l],t);
72     }else{
73     int l1 = l+(r-l)/3,l2 = l+2*(r-l)/3,d = l2-l1;
74     for(int i=0;i<d;i++){
75         cn p1 = f[l+i],p2 = f[l1+i],p3 = f[l2+i];
76         f[l+i] = p1+p2+p3;
77         f[l1+i] = p1+e[1]*p2+e[2]*p3;f[l2+i] = p1+e[2]*p2+e[1]*p3;
78         p1 = val[l+i],p2 = val[l1+i],p3 = val[l2+i];
79         val[l+i] = p1+p2+p3;
80         val[l1+i] = p1+e[1]*p2+e[2]*p3;val[l2+i] = p1+e[2]*p2+e[1]*p3;
81     }
82     multi(l,l1); multi(l1,l2); multi(l2,r);
83     for(int i=0;i<d;i++){
84         cn p1 = f[l+i],p2 = f[l1+i],p3 = f[l2+i];
85         f[l+i] = p1+p2+p3;
86         f[l1+i] = p1+e[2]*p2+e[1]*p3;f[l2+i] = p1+e[1]*p2+e[2]*p3;
87         f[l+i]=f[l+i]*iv3;f[l1+i]=f[l1+i]*iv3;f[l2+i]=f[l2+i]*iv3;
88     }
89     }
90 }
91
92 void init(){
93     phi = p;int z = p;
94     for(int i=2;i*i<=p;i++){
95     if(p % i == 0){
96         while(p%i == 0) p /= i;
97         phi = (phi/i)*(i-1);
98     }
99     }
100     if(p != 1) phi = (phi/p)*(p-1); p =z;
101     iv2 = fast_pow(2,phi-1); iv3 = fast_pow(3,phi-1);
102     e[0] = (cn){1,0}; e[1] = (cn){p-iv2,iv2}; e[2] = (cn){p-iv2,p-iv2};
103 }
104
105 void work(){
106     multi(0,n);//[0,n)
107     for(int i=0;i<n;i++) printf("%d\n",f[i].rl);
108 }
109
110 int main(){
111     read();
112     init();
113     work();
114     return 0;
115 }

 

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: