您的位置:首页 > 其它

3351: [ioi2009]Regions

2016-02-02 22:15 295 查看

3351: [ioi2009]Regions

Time Limit: 120 Sec  Memory Limit: 128 MB
Submit: 205  Solved: 59

[Submit][Status][Discuss]

Description

 
N个节点的树,有R种属性,每个点属于一种属性。有Q次询问,每次询问r1,r2,回答有多少对(e1,e2)满足e1属性是r1,e2属性是r2,e1是e2的祖先。
数据规模
N≤200000,R≤25000,Q≤200000
30%数据R≤500
55%数据同种属性节点个数≤500

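Input

第一行三个整数 N、R、Q;第二行一个整数,表示 1 号节点(根)的属性。
接下来 N-1 行依次描述节点 2..N,每行两个整数,分别为该节点的父节点编号和该节点的属性。
接下来 Q 行,每行两个整数 r1、r2,表示一次询问。

Output

按输入顺序,对每个询问输出一行一个整数,即满足条件的点对 (e1,e2) 的个数。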

Sample Input

6 3 4
1
1 2
1 3
2 3
2 3
5 1
1 2
1 3
2 3
3 1

Sample Output

1
3
2
1

HINT

Source



[Submit][Status][Discuss]

HOME Back

参考于《根号算法——不只是分块》 王悦同

设属性为 r1 的点有 A 个,属性为 r2 的点有 B 个。

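若 A 很小(A ≤ √N),则对每个 r1 点,在按 DFS 序排好的 r2 点中二分出落在其子树区间内的点数,复杂度 $O(A\log B)$;若 B 很小(B ≤ √N),则对每个 r2 点,在 r1 的子树区间端点(已排序并做前缀和)中二分,得到覆盖它的 r1 区间数,复杂度 $O(B\log A)$。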

若 A、B 都很大(均大于 √N),由于点数超过 √N 的属性至多 √N 种,询问去重后这类询问不会很多,直接用 $O(A+B)$ 的归并扫描暴力计算即可,这部分总复杂度为 $O(N\sqrt{N})$。

针对不同类型的询问分别设计专门的算法,以 √N 为阈值(根号分治)平衡各部分的时间复杂度。

蒟蒻编写能力还有待提高!注意数组下标……



#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<vector>
#include<queue>
using namespace std;

// Capacity of every per-node / per-region / per-query array: the limits are
// N <= 200000 and Q <= 200000, plus slack because arrays are indexed from 1.
const int maxn = 200000 + 20;

// One query: the region pair (r1, r2) plus `num`, its 1-based position in the
// input (used to write the answer back in input order).
// Ordering is lexicographic on (r1, r2) so that equal pairs become adjacent
// after sorting; equality compares only the region pair and ignores `num`.
struct Q{
    int r1,r2,num;
    bool operator < (const Q &other) const {
        if (r1 != other.r1) return r1 < other.r1;
        return r2 < other.r2;
    }
    bool operator == (const Q &other) const {
        return r1 == other.r1 && r2 == other.r2;
    }
}q[maxn];

// A sweep event for one region's subtree intervals: at DFS timestamp `pos`
// the number of open intervals changes by `add` (+1 or -1). `sum` is created
// equal to `add`; main later sorts the events and turns `sum` into a running
// total. Events are ordered by position only.
struct P{
    int pos,add,sum;
    bool operator < (const P &rhs) const { return rhs.pos > pos; }
};

// Tree and per-region data; nodes, regions and queries are indexed from 1.
vector <int> v[maxn];    // v[u]: children of node u, in input order
vector <int> v2[maxn];   // v2[r]: nodes of region r in DFS pre-order (so sorted by dfs[])
vector <P> v3[maxn];     // v3[r]: +1/-1 subtree-interval events of region r's non-leaf nodes

int mark[maxn];          // mark[u]: region of node u
int dfs[maxn];           // dfs[u]: pre-order timestamp of node u
int n;                   // number of nodes
int m;                   // number of queries
// ans[i]: answer to the i-th query in input order.
// NOTE(review): with N up to 2e5 a count can exceed INT_MAX (e.g. a path of
// 1e5 r1-nodes above 1e5 r2-nodes gives 1e10 pairs); switching to long long
// also needs the "%d" in main changed — confirm against the judge's limits.
int ans[maxn];
int dfs_clock = 0;       // last timestamp handed out by DFS
int flag;                // region-size threshold, set to sqrt(n) in main
int head[maxn];          // head[u]: first timestamp of u's subtree (equals dfs[u])
int tail[maxn];          // tail[u]: last timestamp of u's subtree

// Reads the next non-negative decimal integer from stdin, skipping every
// non-digit character before it (a '-' sign is skipped like any separator).
// Returns 0 if end of input is reached before any digit.
int getint()
{
    int ret = 0;
    // Must be int, not char: getchar() returns EOF (a negative int), which a
    // char cannot hold reliably (plain char may be unsigned).
    int ch = getchar();
    // Stop at EOF; without this check truncated input loops forever here.
    while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar();
    while ('0' <= ch && ch <= '9') {
        ret = ret * 10 + (ch - '0');
        ch = getchar();
    }
    return ret;
}

void DFS(int k)
{
dfs[k] = head[k] = ++dfs_clock;
v2[mark[k]].push_back(k);
for (int i = 0; i < v[k].size(); i++) DFS(v[k][i]);
tail[k] = dfs_clock;
if (head[k] != tail[k]) {
v3[mark[k]].push_back((P){head[k],1,1});
v3[mark[k]].push_back((P){tail[k]+1,-1,-1});
}
}

// Binary search over v2[x] (ordered by dfs[]) within indices [l, r]:
// returns the smallest index i with dfs[v2[x][i]] >= pos. Index r is returned
// without being checked when no earlier index qualifies, so the caller must
// make sure some index in [l, r] satisfies the condition.
int DBSC1(int x,int l,int r,int pos)
{
    while (r - l > 1) {
        const int mid = (l + r) >> 1;
        if (dfs[v2[x][mid]] >= pos)
            r = mid;
        else
            l = mid;
    }
    return dfs[v2[x][l]] >= pos ? l : r;
}

// Binary search over v2[x] (ordered by dfs[]) within indices [l, r]:
// returns the largest index i with dfs[v2[x][i]] <= pos. Index l is returned
// without being checked when no later index qualifies, so the caller must
// make sure some index in [l, r] satisfies the condition.
int DBSC2(int x,int l,int r,int pos)
{
    while (r - l > 1) {
        const int mid = (l + r) >> 1;
        if (dfs[v2[x][mid]] <= pos)
            l = mid;
        else
            r = mid;
    }
    return dfs[v2[x][r]] <= pos ? r : l;
}

// Binary search over the events v3[x] (sorted by pos) within indices [l, r]:
// returns the largest index i with v3[x][i].pos <= pos. Index l is returned
// without being checked when no later index qualifies, so the caller must
// make sure some index in [l, r] satisfies the condition.
int DBSC(int x,int l,int r,int pos)
{
    while (r - l > 1) {
        const int mid = (l + r) >> 1;
        if (v3[x][mid].pos <= pos)
            l = mid;
        else
            r = mid;
    }
    return v3[x][r].pos <= pos ? r : l;
}

// Answers query x (used when region r1 is small). For every r1 node with
// descendants, the r2 nodes in its subtree are exactly those whose timestamps
// lie in [head, tail]. They are counted with two binary searches over
// v2[r2], which is ordered by dfs[]. The result goes to ans[q[x].num].
void solve1(int x)
{
    const int r1 = q[x].r1;
    const int r2 = q[x].r2;
    const vector<int> &upper = v2[r1];
    const vector<int> &lower = v2[r2];
    int total = 0;
    for (size_t idx = 0; idx < upper.size(); ++idx) {
        const int u = upper[idx];
        const bool isLeaf = head[u] == tail[u];
        if (isLeaf) continue;
        // Skip subtrees lying entirely before or after all r2 timestamps.
        if (tail[u] < dfs[lower.front()] || head[u] > dfs[lower.back()]) continue;
        const int lo = DBSC1(r2, 0, lower.size() - 1, head[u]);
        const int hi = DBSC2(r2, 0, lower.size() - 1, tail[u]);
        total += hi - lo + 1;  // 0 when no r2 timestamp falls inside [head, tail]
    }
    ans[q[x].num] = total;
}

// Answers query x (used when region r2 is small). For every r2 node, it
// binary-searches the sorted events v3[r1], whose `sum` fields main turned
// into running totals, for the last event at or before the node's timestamp.
// That running total is the number of r1 intervals covering the node, and it
// is added to the answer. The result goes to ans[q[x].num].
void solve2(int x)
{
    const int r1 = q[x].r1;
    const int r2 = q[x].r2;
    // If every r1 node is a leaf, r1 has no events and no r2 node can have an
    // r1 ancestor. Indexing v3[r1][0] / size()-1 would read out of bounds.
    if (v3[r1].empty()) {
        ans[q[x].num] = 0;
        return;
    }
    const int last = v3[r1].size() - 1;
    const int firstPos = v3[r1][0].pos;
    const int lastPos = v3[r1][last].pos;
    int ANS = 0;
    for (size_t i = 0; i < v2[r2].size(); i++) {
        const int t = dfs[v2[r2][i]];
        // Before the first event or after the last one nothing is open.
        if (t < firstPos || t > lastPos) continue;
        const int POS = DBSC(r1, 0, last, t);
        ANS += v3[r1][POS].sum;
    }
    ans[q[x].num] = ANS;
}

void solve3(int x)
{
int ANS,SUM,L,R;
int r1 = q[x].r1;
int r2 = q[x].r2;
ANS = SUM = L = R = 0;

while (L < v3[r1].size() && R < v2[r2].size()) {
if (v3[r1][L].pos <= dfs[v2[r2][R]]) {
SUM += v3[r1][L].add;
L++;
}
else {
ANS += SUM;
R++;
}
}

//while (R < v2[r2].size()) ANS += SUM,R++;
ans[q[x].num] = ANS;
}

int main()
{
#ifdef YZY
freopen("yzy.txt","r",stdin);
#endif

int tt;
cin >> n >> tt >> m >> mark[1];
flag = sqrt(n);

for (int i = 2; i <= n; i++) {
int x,y;
x = getint(); y = getint();
v[x].push_back(i); mark[i] = y;
}
DFS(1);

/*for (int i = 1; i <= tt; i++) {
for (int j = 0; j < v2[i].size(); j++) {
int kk = v2[i][j];
int b = 1;
}
for (int j = 0; j < v3[i].size(); j++) {
int pos = v3[i][j].pos;
int add = v3[i][j].add;
int sum = v3[i][j].sum;
int b = 1;
}
}*/

for (int i = 1; i <= tt; i++) {
sort(v3[i].begin(),v3[i].end());
for (int j = 1; j < v3[i].size(); j++)
v3[i][j].sum += v3[i][j-1].sum;
}

for (int i = 1; i <= m; i++) {
int x,y;
x = getint(); y = getint();
q[i] = (Q){x,y,i};
}
sort(q+1,q+m+1);

for (int i = 1; i <= m; i++) {
if (q[i] == q[i-1]) {
ans[q[i].num] = ans[q[i-1].num];
continue;
}
int S1 = v2[q[i].r1].size();
int S2 = v2[q[i].r2].size();
if (S1 <= flag && S2 <= flag) {
solve3(i);
continue;
}
if (S1 <= flag) {
solve1(i);
continue;
}
if (S2 <= flag) {
solve2(i);
continue;
}
solve3(i);
}

for (int i = 1; i <= m; i++) printf("%d\n",ans[i]);

return 0;
}


内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: