您的位置:首页 > 其它

BZOJ1036 树的统计

2016-05-14 23:28 323 查看

Description

  一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成
一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身

Input

  输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有
一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作
的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。

Output

  对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。

Sample Input

4

1 2

2 3

4 1

4 2 1 3

12

QMAX 3 4

QMAX 3 3

QMAX 3 2

QMAX 2 3

QSUM 3 4

QSUM 2 1

CHANGE 1 5

QMAX 3 4

CHANGE 3 6

QMAX 3 4

QMAX 2 4

QSUM 3 4

Sample Output

4

1

2

2

10

6

5

6

5

16

正解:树链剖分+线段树

解题报告:

  维护树上一条路径上的结点权值最大值或和

  没什么好说的,链剖裸题。先树链剖分再根据访问次序建立线段树,用线段树动态维护。

  模板题练手。

//It is made by jump~
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
typedef long long LL;
const int MAXN = 30011;
const int inf = (1<<30);
int n;
int total,ecnt;
int U,VV;
int a[MAXN];
int id[MAXN],pre[MAXN];
int top[MAXN],siz[MAXN],zhongerzi[MAXN],father[MAXN],deep[MAXN];
int next[MAXN*2],to[MAXN*2],first[MAXN];
char ch[8];

struct node{
int l,r;
int _max;int _sum;
}jump[MAXN*4];

void link(int x,int y){ next[++ecnt]=first[x]; first[x]=ecnt; to[ecnt]=y; }

int getint()
{
int w=0,q=0;
char c=getchar();
while((c<'0' || c>'9') && c!='-') c=getchar();
if (c=='-')  q=1, c=getchar();
while (c>='0' && c<='9') w=w*10+c-'0', c=getchar();
return q ? -w : w;
}

void build(int root,int l,int r){
jump[root].l=l;jump[root].r=r;
if(jump[root].l==jump[root].r) {
jump[root]._sum=jump[root]._max=a[ pre[l] ];
return ;
}
int lc=root*2,rc=root*2+1;
int mid=l+(r-l)/2;
build(lc,l,mid); build(rc,mid+1,r);
jump[root]._sum=jump[lc]._sum+jump[rc]._sum;
jump[root]._max=max(jump[lc]._max,jump[rc]._max);
}

void dfs1(int u,int fa){
siz[u]=1;
for(int i=first[u];i;i=next[i]) {
int v=to[i];
if(v!=fa) {
father[v]=u;
deep[v]=deep[u]+1;
dfs1(v,u);
siz[u]+=siz[v];
if(siz[v]>siz[ zhongerzi[u] ]) zhongerzi[u]=v;
}
}
}

void dfs2(int u,int fa){
id[u]=++total; pre[total]=u;
if(zhongerzi[u]) top[zhongerzi[u]]=top[u],dfs2(zhongerzi[u],u);
for(int i=first[u];i;i=next[i]) {
int v=to[i];
if(v==fa || v==zhongerzi[u]) continue;
top[v]=v;
dfs2(v,u);
}
}

int query_sum(int root,int x,int y){
if(jump[root].l>=x && jump[root].r<=y) return jump[root]._sum;
int da=0;
int mid=jump[root].l+(jump[root].r-jump[root].l)/2;
int lc=root*2,rc=root*2+1;
if(x<=mid) da+=query_sum(lc,x,y);
if(y>mid) da+=query_sum(rc,x,y);
return da;
}

int query_max(int root,int x,int y){
if(jump[root].l>=x && jump[root].r<=y) return jump[root]._max;
int da=-inf;
int mid=jump[root].l+(jump[root].r-jump[root].l)/2;
int lc=root*2,rc=root*2+1;
if(x<=mid) da=max(da,query_max(lc,x,y));
if(y>mid) da=max(da,query_max(rc,x,y));
return da;
}

int find_max(int x,int y){
int f1=top[x],f2=top[y];
int daan=-inf;
while(f1!=f2){
if(deep[f1]<deep[f2]) swap(f1,f2),swap(x,y);
daan=max(daan,query_max(1,id[f1],id[x]));
x=father[f1];
f1=top[x];
}
if(deep[x]<deep[y]) swap(x,y);
daan=max(daan,query_max(1,id[y],id[x]));
return daan;
}

int find_sum(int x,int y){
int f1=top[x],f2=top[y];
int daan=0;
while(f1!=f2){
if(deep[f1]<deep[f2]) swap(f1,f2),swap(x,y);
daan+=query_sum(1,id[f1],id[x]);
x=father[f1]; f1=top[x];
}
if(deep[x]<deep[y]) swap(x,y);
daan+=query_sum(1,id[y],id[x]);
return daan;
}

void update(int root,int o,int add){
if(jump[root].l==jump[root].r){
jump[root]._sum+=add;
jump[root]._max+=add;return ;
}
int lc=root*2,rc=root*2+1;
int mid=jump[root].l+(jump[root].r-jump[root].l)/2;
if(o<=mid) update(lc,o,add); else update(rc,o,add);
jump[root]._sum=jump[lc]._sum+jump[rc]._sum;
jump[root]._max=max(jump[lc]._max,jump[rc]._max);
}

int main()
{
n=getint();
int x,y;
for(int i=1;i<n;i++){
x=getint();y=getint();
next[++ecnt]=first[x]; first[x]=ecnt; to[ecnt]=y;
next[++ecnt]=first[y]; first[y]=ecnt; to[ecnt]=x;
}

deep[1]=1;  dfs1(1,0);
top[1]=1;  dfs2(1,0);

for(int i=1;i<=n;i++) a[i]=getint();
build(1,1,n);
int Q=getint();

for(int i=1;i<=Q;i++){
scanf("%s",ch);
if(ch[1]=='M'){
printf("%d\n",find_max(x,y));
}
else if(ch[1]=='S'){
x=getint();y=getint();
printf("%d\n",find_sum(x,y));
}
else{
U=getint();VV=getint();
update(1,id[U],VV-a[U]);a[U]=VV;
}
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: