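// 您的位置:首页 > 其它  (Your location: Home > Other)
//
// UOJ 33 树上GCD (树分治) — "GCD on a Tree", solved with tree divide-and-conquer
//
// 2016-07-29 17:06, 323 查看 (323 views)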
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <queue>
#include <vector>
using namespace std;

// Capacity limits: N bounds node ids and every per-node / per-depth array, M bounds the
// edge pool, and B is the divisor threshold up to which divide() caches residue sums in F.
#define N 300020
#define B 300
#define M 600200
#define inf 0x3f3f3f3f
#define mod 1000000007
#define LL long long
#define ls (i << 1)
#define rs (ls | 1)
#define md ((ll + rr) >> 1)
#define lson ll, md, ls
#define rson md + 1, rr, rs
#define MP make_pair
#define ui unsigned int

int n;          // number of nodes, read in main()
int fa[N];      // fa[i]: parent of node i, read for i = 2..n (node 1 is the root)

// Adjacency list in "forward star" form: fst[u] is the first edge id leaving u (-1 = none),
// nxt[id] is the next edge of the same node, vv[id] is the edge target, e the next free id.
int fst[N], nxt[M], vv[M], e;

// BFS queue shared by bfs()/get_rt()/divide(); after bfs(), q[0..qt) holds the visit order.
int q[N], qh, qt;

// Per-node scratch: dep = depth assigned by the last bfs(); sz = remaining subtree size and
// mx = size of the largest piece left if the node were removed (both filled by get_rt()).
int dep[N], sz[N], mx[N];

// Answer accumulators indexed by value 1..n-1; main() prints ans1[i] + ans2[i].
LL ans1[N], ans2[N];

// vis[u] set: u is removed (an already-processed centroid, or temporarily blocked in divide()).
bool vis[N];

int rt;         // centroid found by the most recent get_rt() call

// Per-depth counters used inside divide(); divide() zeroes the entries it touched before returning.
LL x[N], y[N], xx[N], yy[N];

// F[i][r]: cached residue sums for small divisors i <= B; -1 marks "not computed yet".
LL F[B + 10][B + 10];

// Empties the adjacency list: the edge counter restarts at 0 and every head pointer
// in fst[] becomes -1 (all bytes 0xff), meaning "no outgoing edge".
void init() {
    e = 0;
    memset(fst, 0xff, sizeof(fst));
}
// Appends the directed edge u -> v: the new edge takes id e (which is then advanced)
// and is pushed onto the front of u's list in fst/nxt, with its target stored in vv.
void add(int u, int v) {
    const int id = e++;
    vv[id] = v;
    nxt[id] = fst[u];
    fst[u] = id;
}

void bfs(int s, int fd) {
dep[s] = fd;
qh = qt = 0;
q[qt++] = s;
while(qh < qt) {
int u = q[qh++];
for(int i = fst[u]; ~i; i = nxt[i]) {
int v = vv[i];
if(vis[v]) continue;
dep[v] = dep[u] + 1;
q[qt++] = v;
}
}
}

void get_rt(int u) {
bfs(u, 0);
for(int i = qt - 1; i >= 0; --i) {
int u = q[i];
sz[u] = 1;
for(int j = fst[u]; ~j; j = nxt[j]) {
int v = vv[j];
if(vis[v]) continue;
sz[u] += sz[v];
}
}
int tmp = n + 1;
for(int i = 0; i < qt; ++i) {
int u = q[i];
mx[u] = 0;
for(int j = fst[u]; ~j; j = nxt[j]) {
int v = vv[j];
if(vis[v]) continue;
mx[u] = max(mx[u], sz[v]);
}
mx[u] = max(mx[u], qt - sz[u]);
if(mx[u] < tmp) tmp = mx[u], rt = u;
}
}

// Handles one component of the tree decomposition and then recurses on what is left.
// The component is the part of s's subtree reachable without crossing vis[] nodes
// (bfs() only follows parent -> child edges). Its centroid rt is found by get_rt().
// Contributions are added to the globals:
//   - ans2[k]: one per ancestor/descendant pair at distance k counted here, i.e.
//     (rt, node below rt) and (ancestor of rt up to s, node in rt's component).
//   - ans1[g]: number of pairs, meeting at rt or at an ancestor of rt, whose two distances
//     to the meeting node are both divisible by g. main() turns these into exact-gcd counts.
// Scratch arrays x/y/xx/yy are zeroed and the touched rows of F reset to -1 before recursing.
void divide(int s) {
get_rt(s);
// Remove the centroid so every bfs() below stops at it.
vis[rt] = 1;
// rt is a global overwritten by the recursive calls; keep our own copy.
int rrt = rt;
// tp: largest child-subtree size seen, which bounds the depths stored in x[].
int tp = 0;

// Phase 1: pairs meeting at the centroid. Child subtrees are scanned one at a time;
// y[d] counts the current subtree's nodes at distance d from rt.
for(int i = fst[rt]; ~i; i = nxt[i]) {
int v = vv[i];
if(vis[v]) continue;
bfs(v, 1);
tp = max(tp, qt);
for(int j = 0; j < qt; ++j) {
int u = q[j];
y[dep[u]]++;
// (rt, u) is an ancestor/descendant pair at distance dep[u].
ans2[dep[u]]++;
}
// x[d]: nodes at depth d over all child subtrees so far.
// yy[j]: nodes of the current subtree whose depth is a multiple of j.
for(int j = 1; j <= qt; ++j) {
x[j] += y[j];
for(int k = j; k <= qt; k += j) {
yy[j] += y[k];
}
}
// Pair the current subtree with earlier ones (xx[j] = same multiple-of-j count over
// earlier subtrees), fold it into xx, and clear the per-subtree scratch.
for(int j = 1; j <= qt; ++j) {
ans1[j] += 1LL * xx[j] * yy[j];
xx[j] += yy[j];
yy[j] = y[j] = 0;
}
}
// The centroid itself, at depth 0.
x[0]++;
// Largest bfs size seen while walking the ancestors; bounds which rows of F were filled.
int mx_qt = 0;

// Phase 2: walk the ancestors of rt up to and including s. At the t-th ancestor u, the nodes
// below u that are not on the already-walked chain (bfs stops at vis[] nodes) pair with every
// node counted in x[]. Such a pair meets at u with distances (t + j, d).
int t = 1;
for(int u = rt; u != s; ++t) {
u = fa[u];
bfs(u, 0);
mx_qt = max(mx_qt, qt);
// Start at 1: q[0] is u itself.
for(int i = 1; i < qt; ++i) {
y[dep[q[i]]]++;
}
// yy[i]: nodes below u on this side whose depth is a multiple of i.
for(int i = 1; i <= qt; ++i) {
for(int j = i; j <= qt; j += i)

yy[i] += y[j];
}
// Small divisors: F[i][t % i] caches the sum of x[j] (j <= tp) over j with (t + j) % i == 0.
// It depends only on t % i, so it is reused across ancestors; -1 means not computed yet.
for(int i = 1; i <= min(qt, B); ++i) {
if(F[i][t%i] == -1) {
F[i][t%i] = 0;
for(int j = (i - t % i) % i; j <= tp; j += i) {
F[i][t%i] += x[j];
}
}
ans1[i] += F[i][t%i] * yy[i];
}
// Large divisors: sum the matching x[j] directly.
for(int i = B + 1; i <= qt; ++i) {
for(int j = (i - t % i) % i; j <= tp; j += i)
ans1[i] += x[j] * yy[i];
}
for(int i = 1; i <= qt; ++i) y[i] = yy[i] = 0;
// Block u so the next ancestor's bfs() does not walk back down this branch.
vis[u] = 1;
}
// Reset the cache rows that may have been filled (i <= min(qt, B) <= min(mx_qt, B)).
for(int i = 1; i <= min(mx_qt, B); ++i) {
for(int j = 0; j < i; ++j)
F[i][j] = -1;
}
// Unblock the ancestor chain; only the centroid stays removed.
for(int u = rt; u != s;) {
u = fa[u];
vis[u] = 0;
}

// t is now the number of ancestors walked. Pair every node at depth j below rt with each
// ancestor at distance a = 1..t, giving distance k = j + a. tmp is the sliding window
// x[k-1] + ... + x[k-t], clipped to the valid range 0..tp.
--t;
LL tmp = 0;
for(int k = 1; k <= tp + t; ++k) {
if(k - 1 <= tp) tmp += x[k-1];
if(k - t - 1 >= 0 && k - t - 1 <= tp) tmp -= x[k-t-1];
ans2[k] += tmp;
}
for(int i = 0; i <= tp; ++i) x[i] = xx[i] = 0;
// Recurse: the rest of s's component (when rt was not its top) ...
if(rrt != s) divide(s);
// ... and each unvisited child of the centroid, which roots a new component.
for(int i = fst[rrt]; ~i; i = nxt[i]) {
int v = vv[i];
if(!vis[v]) {
divide(v);
}
}
}
// Reads n and the parent of every node 2..n (node 1 is the root) and stores the tree as
// parent -> child edges. It then runs the decomposition from node 1 and prints, for every
// value g in 1..n-1, ans1[g] + ans2[g], one per line.
int main() {
    scanf("%d", &n);
    init();
    for (int child = 2; child <= n; ++child) {
        scanf("%d", &fa[child]);
        add(fa[child], child);
    }

    // Every cache entry of F starts as "not computed" (-1).
    memset(F, -1, sizeof F);
    divide(1);

    // ans1[g] counts pairs whose two distances are both divisible by g, i.e. whose gcd is a
    // multiple of g. Going from the largest g down, the counts for the multiples are already
    // exact, so subtracting them leaves the pairs whose gcd is exactly g.
    for (int g = n - 1; g >= 1; --g) {
        for (int mult = 2 * g; mult < n; mult += g) {
            ans1[g] -= ans1[mult];
        }
    }

    for (int g = 1; g < n; ++g) {
        printf("%lld\n", ans1[g] + ans2[g]);
    }
    return 0;
}
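// 内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
// (Content shared by users and collected from the web; its accuracy is not guaranteed.
//  For infringement claims, contact the site administrator. "Click here to send me a message.")
// 标签: (Tags:)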