您的位置:首页 > 其它

bzoj2286 消耗战 虚树&树形dp

2016-03-06 20:11 375 查看
学习了一下虚树,讲一下自己的理解。

虚树是这么一个东西,对于需要求答案的点p,只保留对答案有影响的节点,从而减少时间。

现在对于这道题目进行特定的说明。

考虑朴素的dp,显然,dp[i]=min(val[i],Σdp[j](j为i的儿子)),val[i]表示将i和根节点分离的代价。那么这样的时间复杂度为O(N),总时间复杂度O(NM)。

注意题目中有Σk<=500000,因此如果能将一次的时间复杂度减小到O(K)或者O(KlogK),就能通过了。因此,关键是能构造出一颗节点<=O(K)级别的虚树,以及能在O(K)或者O(KlogK)的时间构造出虚树。

定义某一次询问给出的岛屿为关键点。注意到对于某对关键点(x,y),考虑x->lca(x,y)的路径中,没有那个点是某一对关键点的lca,那么显然x->lca(x,y)的路径上的点对答案不会产生任何影响,换句话说将lca(x,y)->...->x的路径直接压缩成lca(x,y)->x,对答案不会产生影响。因此我们只需要保留所有关键点,以及它们两两之间的lca,然后按照原数的祖先关系连边,在构造得到的虚树上面跑dp即可。注意到k个点两两之间不同的lca只有k-1个,因此产生的虚树是O(K)的。

下面来构造,首先按照在原树上的dfs将岛屿进行排序。

用一个栈维护从根节点到栈顶的路径上,需要加入虚树的点形成的一条链。注意不能直接将链中的点连边,因为有可能会有新的关键点的lca也在这条链上。考虑现在新加进来一个点x,以及栈顶的点p,t=lca(p,x),有以下几种情况。

1.t=p,那么把x加入栈就好了;

2.t!=p,那么显然t是p的祖先,那么显然栈中t->p的路径上面的点都可以加入虚树了,因为不可能会有新的lca插入到t->p的路径中来。

考虑2的具体实现,只要去除栈中的第二个元素,不断比较是否是t的祖先即可。

AC代码如下:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 250005
#define ll long long
#define inf 1e60
using namespace std;

int n,m,tot,dfsclk,bin[25],pnt[N<<1],len[N<<1],nxt[N<<1],a
,d
,fa
[18],pos
,q
;
ll val
;
struct node{
int fst
;
void add(int x,int y,int z){
if (x==y) return;
pnt[++tot]=y; len[tot]=z; nxt[tot]=fst[x]; fst[x]=tot;
}
}g1,g2;
int read(){
int x=0; char ch=getchar();
while (ch<'0' || ch>'9') ch=getchar();
while (ch>='0' && ch<='9'){ x=x*10+ch-'0'; ch=getchar(); }
return x;
}
void dfs(int x){
pos[x]=++dfsclk; int i,p;
for (i=1; bin[i]<=d[x]; i++) fa[x][i]=fa[fa[x][i-1]][i-1];
for (p=g1.fst[x]; p; p=nxt[p]){
int y=pnt[p];
if (y!=fa[x][0]){
fa[y][0]=x; val[y]=min(val[x],(ll)len[p]);
d[y]=d[x]+1; dfs(y);
}
}
}
int lca(int x,int y){
if (d[x]<d[y]) swap(x,y); int tmp=d[x]-d[y],i;
for (i=0; bin[i]<=tmp; i++)
if (tmp&bin[i]) x=fa[x][i];
for (i=17; i>=0; i--)
if (fa[x][i]!=fa[y][i]){ x=fa[x][i]; y=fa[y][i]; }
return (x==y)?x:fa[x][0];
}
ll dp(int x){
if (!g2.fst[x]) return val[x];
int p; ll tmp=0;
for (p=g2.fst[x]; p; p=nxt[p]) tmp+=dp(pnt[p]);
g2.fst[x]=0; return min(tmp,val[x]);
}
bool cmp(int x,int y){ return pos[x]<pos[y]; }
void solve(){
int i,cnt=read(),pts=1,tp=1; tot=0;
for (i=1; i<=cnt; i++) a[i]=read();
sort(a+1,a+cnt+1,cmp);
for (i=2; i<=cnt; i++)
if (lca(a[i],a[pts])!=a[pts]) a[++pts]=a[i];
q[1]=1;
for (i=1; i<=pts; i++){
int tmp=lca(a[i],q[tp]);
while (1){
if (d[q[tp-1]]<=d[tmp]){
g2.add(tmp,q[tp--],0);
if (q[tp]!=tmp) q[++tp]=tmp; break;
}
g2.add(q[tp-1],q[tp],0); tp--;
}
if (q[tp]!=a[i]) q[++tp]=a[i];
}
while (tp>1){ g2.add(q[tp-1],q[tp],0); tp--; }
printf("%lld\n",dp(1));
}
int main(){
n=read(); int i;
bin[0]=1; for (i=1; i<=18; i++) bin[i]=bin[i-1]<<1;
for (i=1; i<n; i++){
int x=read(),y=read(),z=read();
g1.add(x,y,z); g1.add(y,x,z);
}
val[1]=inf; dfs(1);
m=read(); while (m--) solve();
return 0;
}


by lych
2016.3.6
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: