您的位置:首页 > 理论基础 > 数据结构算法

2016计蒜客初赛第一场 青云的机房组网方案(困难):图论+虚树+容斥

2016-06-10 08:56 507 查看
题目链接:http://nanti.jisuanke.com/t/11135

题目概述:给一棵n个节点的树,每个节点有一个初始值ai。1<=n<=100000,1<=ai<=100000。求树上任意两个值互质点距离的和。

思路概述:

枚举 因数x,x是每种质因子至多有一个的数,记录一下x有几种质因子,方便之后容斥。
把所有x的倍数的权值的点找出来,预处理下可以做到找出来的点的dfs序是从小到大的,预处理也可以使得每次找x的倍数的权值的点不必线性扫一遍。
然后对这些点 O(n) 建虚树,具体操作是相邻两个点加进去 lca,用一个栈维护下父亲链即可。[bzoj3572]是一道典型的虚树的题目。
构建好树后在树上 dfs 两次可以求出所有x的倍数的权值的点对之间的距离和,就是第一遍dfs记录以节点u为根的子树中,有多少个x倍数的点(可能有一些是虚树添加进来的lca点),第二遍dfs其实是枚举每条边,计算(u,v)这条边的总价值,就是它出现的次数乘以它的权值;它出现的次数就是它子树中x倍数的点的个数,乘以不在它子树中x倍数的点的个数。
最后容斥下就可以求出答案。

由于所有步骤均是线性的,而所有虚树加起来的总点数也是线性乘上一个常数的,所以复杂度为 O(nK),K<=128。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
const int maxp = 316;
const int maxn = 100100;
const int maxl = 17;
bool isprime[maxp+5];
int prime[maxp+5],pnum = 0;
int a[maxn];
int anc[maxn][maxl+1],dep[maxn],cur[maxn],Stack[maxn<<1],dfn[maxn];
int n,dfs_seq=0;
vector<int> arr[maxn];
vector<int> factor[maxn];
int rongchi[maxn];
struct Node
{
int head[maxn],nex[maxn<<1],point[maxn<<1],weight[maxn<<1],siz[maxn];
bool selected[maxn];
map<int,int> label;
int ne,total,nl;
ll ans;
void init(int tot)
{
label.clear();
total = tot;
ne = 0;
ans = 0;
nl = 0;
}
void addedge(int u,int v, int w)
{
if(label.count(u)) u = label[u];
else
{
label[u] = ++nl;
u = nl;
head[u] = -1;
selected[u] = false;
}
if(label.count(v)) v = label[v];
else
{
label[v] = ++nl;
v = nl;
head[v] = -1;
selected[v] = false;
}
point[ne] = v;
nex[ne] = head[u];
weight[ne] = w;
head[u] = ne++;
point[ne] = u;
nex[ne] = head[v];
weight[ne] = w;
head[v] = ne++;
}
void set(int x)
{
if(!label.count(x))
{
label[x] = ++nl;
x = nl;
}
else x = label[x];
head[x] = -1;
selected[x] = true;
}
void dfs1(int root,int fa)
{
siz[root] = selected[root]?1:0;
for(int i=head[root]; i!=-1; i=nex[i])
{
if(point[i] == fa) continue;
dfs1(point[i],root);
siz[root] += siz[point[i]];
}
}
void dfs2(int root,int fa)
{
for(int i=head[root]; i!=-1; i=nex[i])
{
if(point[i]==fa) continue;
ans += 1LL*weight[i]*siz[point[i]]*(total-siz[point[i]]);
dfs2(point[i],root);
}
}
} g1,g2;

void dfs(int root)
{
int top = 0;
dep[root] = 1;
for(int i=0; i<=maxl; i++)
{
anc[root][i] = root;
}
Stack[++top] = root;
memcpy(cur,g1.head,sizeof(cur));
while(top)
{
int x = Stack[top];
if(x != root)
{
for(int i=1; i<=maxl; i++)
{
int y = anc[x][i-1];
anc[x][i] = anc[y][i-1];
}
}
for(int &i = cur[x]; i!= -1; i=g1.nex[i])
{
int y = g1.point[i];
if(y != anc[x][0])
{
dep[y] = dep[x]+1;
anc[y][0] = x;
Stack[++top] = y;
}
}
while(top && cur[Stack[top]] == -1) top--;
}
}
void swim(int &x,int H)
{
for(int i=0; H>0; i++)
{
if(H&1) x = anc[x][i];
H >>= 1;
}
}
int lca(int x,int y)
{
int i;
if(dep[x] > dep[y]) swap(x,y);
swim(y,dep[y]-dep[x]);
if(x == y) return x;
while(true)
{
for(i=0; anc[x][i] != anc[y][i]; i++);
if(i == 0) return anc[x][0];
x = anc[x][i-1];
y = anc[y][i-1];
}
return -1;
}
void getfactor(int pointer,int cur,int acc,const vector<pii> &tmp,int num)
{
if(pointer >= tmp.size()){
if(cur > 1) factor[num].push_back(cur);
return;
}
if(acc == 0) getfactor(pointer,cur*tmp[pointer].first,acc+1,tmp,num);
getfactor(pointer+1,cur,0,tmp,num);
}
void init(int ma)
{
memset(isprime,true,sizeof(isprime));
memset(rongchi,0,sizeof(rongchi));
int maxpp = (int)sqrt(ma+1.0);
for(int i=2; i<=maxpp; i++)
{
if(isprime[i]) prime[pnum++] = i;
for(int j=0; j<pnum; j++)
{
if(i*prime[j] > maxpp) break;
isprime[i*prime[j]] = false;
if(i%prime[j] == 0) break;
}
}
vector<pii> tmp;
for(int i=2; i<=ma; i++)
{
tmp.clear();
int ti = i;
for(int j=0; j<pnum && ti > 1; j++)
{
int cnt = 0;
while(ti%prime[j] == 0)
{
ti /= prime[j];
cnt++;
rongchi[i]++;
}
if(cnt) tmp.push_back(make_pair(prime[j],cnt));
}
if(ti > 1) tmp.push_back(make_pair(ti,1)),rongchi[i]++;
getfactor(0,1,0,tmp,i);
}
}
void dfs_dfn(int index)
{
dfn[index] = dfs_seq++;
int fsiz = factor[a[index]].size();
for(int i=0; i<fsiz; i++)
{
arr[factor[a[index]][i]].push_back(index);
}
for(int i=g1.head[index]; i!=-1; i=g1.nex[i])
{
if(g1.point[i] == anc[index][0]) continue;
dfs_dfn(g1.point[i]);
}
}
void work(int index)
{
int top=0;
int cnt = arr[index].size();
g2.init(cnt);
for (int i=0; i<cnt; i++)
{
g2.set(arr[index][i]);
if (!top)
{
Stack[++top]=arr[index][i];
continue;
}
int u=lca(Stack[top],arr[index][i]);
while (dfn[u]<dfn[Stack[top]])
{
if (dfn[u]>=dfn[Stack[top-1]])
{
g2.addedge(u,Stack[top],dep[Stack[top]]-dep[u]);
if (Stack[--top]!=u) Stack[++top]=u;
break;
}
g2.addedge(Stack[top-1],Stack[top],dep[Stack[top]]-dep[Stack[top-1]]),top--;
}
Stack[++top]=arr[index][i];
}
while (top>1) g2.addedge(Stack[top-1],Stack[top],dep[Stack[top]]-dep[Stack[top-1]]),top--;
g2.dfs1(1,0);
g2.dfs2(1,0);
}
int main()
{
int u,v,ma = -1;
scanf("%d",&n);
g1.init(n);
for(int i=1; i<=n; i++) scanf("%d",a+i), g1.set(i), ma = max(ma,a[i]);
for(int i=0; i<n-1; i++)
{
scanf("%d%d",&u,&v);
g1.addedge(u,v,1);
}
init(ma);
dfs(1);
dfs_dfn(1);
g1.dfs1(1,0);
g1.dfs2(1,0);
ll ans = g1.ans;
for(int i=2; i<=ma; i++
9188
)
{
if(arr[i].size())
{
work(i);
if(rongchi[i]&1) ans -= g2.ans;
else ans += g2.ans;
}
}
cout<<ans<<endl;
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息