您的位置:首页 > 其它

poj1947 ----- 树形DP - 分组背包做法

2015-11-08 11:24 281 查看
题意 : 给你一个树 , 求出得到p个节点的子树所需要的最少的切割边数 。

思路 : 很明显的DP ,但是写起来 各种 烦 , 足足写了半天才改好 。 有一种左儿子右兄弟的DP方法看起来好像特别好想状态转移方程 。但是我习惯用常规的建图方法 和 DP

做法 : 状态 dp[rt][j] 表示以 rt 为根的(并且rt一定在子树上)的得到节点树为j的最小切割数 。 (注意 , 在树状DP中根节点是否包扩是很重要的 , 有的时候限定一定包括根节点会很方便,而且主要的思考也该在这上面 。 ) !!!

转移 : 1.dp[rt][j] = dp[rt][j] ++ ;(这个转移很特殊 , 一般的背包是等于自己 ,但是这里如果新的子树不选择节点那么要 + 1 , 因为要切割开rt ,与 这个 节点的边才可以是0) 。

2.dp[rt][j] = dp[rt][j-k] + dp[son][k] ; (使用了滚动数组,所以要从大到小, 防止数据被覆盖 ) 。

总体来说 , 就是每一颗子树看作是一个背包 , 选几个都要考虑进来 , 但是使用了滚动数组 , 原来因该是 dp[rt][i][j] i表示现在在考虑前i颗子树 。 其余不变

初值 : dp[k][0] = dp[k][1](考虑到第0颗子树,那也就是只有一个节点,就是rt本身,不用减去 = 0) = 0 ;

ans = dp[k][p] ; (p为所求的 , 且k==root) ;

ans = dp[k][p] + 1; k != root (这里是表示与根节点无关系,所以要减去与root的那个连线 也就是 + 1) ;

#include <stdio.h>
#include <string>
#include <string.h>
#include <queue>
#include <stack>
#include <map>
#include <iostream>
#include <stdlib.h>
#include <math.h>
#include <algorithm>
#define mod 10e9+7;
#define inf 0x3f3f3f3f;
#define mem0(x , y)  memset(x , y , sizeof(x))
using namespace std;
const int MAX = 2000 ;
struct node{
int s, e , next ;
}Edge[MAX]  ;
int head[MAX] , cnt  ;
int dp[MAX][MAX]  ;
void init(){
mem0(head , -1) ;
cnt = 0 ;
mem0(dp , 0x3f3f3f3f) ;
for(int i=0;i<MAX;i++){
dp[i][0] = 0 ;
dp[i][1] = 0 ;
}
}
void add(int s ,int e) {
Edge[cnt].s  = s ; Edge[cnt].e = e ; Edge[cnt].next = head[s] ;
head[s] = cnt ++ ;
}
int son[MAX]  ;
int p ;
int ans = 0x3f3f3f3f ;
void dfs(int rt , int pre){
son[rt] = 1 ;
int flag = 0 ;
for(int i = head[rt] ; i != -1 ; i = Edge[i].next){
int v = Edge[i].e  ;
if(v == pre) continue ;
flag ++ ;
dfs(v , rt) ;
son[rt] += son[v] ;
for(int i=son[rt];i>=1;i--){
dp[rt][i] ++ ;
for(int j=1;j<=son[v]&&j+1<=i;j++){
dp[rt][i] = min(dp[rt][i] , dp[rt][i-j] + dp[v][j]) ;
}
///printf("dp[%d][%d] = %d\n" , rt , i , dp[rt][i]) ;
}
}
if(rt != 1)  ans = min(ans , dp[rt][p]+1) ;
else ans = min(ans , dp[rt][p]) ;
}
int main(){
///freopen("input" , "r" , stdin ) ;
int  n;
int  a, b ;
init() ;
scanf("%d%d",&n,&p) ;
for(int i=0;i<n-1;i++){
scanf("%d%d",&a,&b) ;
add(a , b) ; add(b , a) ;
}
dfs(1 , -1) ;
printf("%d\n",ans) ;
}


代码如下 :
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: