博客园 首页 私信博主 显示目录 隐藏目录 管理

树形DP

poj2342
题意:子节点和父亲节点不能同时选,问最大价值。
分析:f[i][0]表示不选节点i,f[i][1]表示选择节点i。
假设u为父亲,v为儿子,可得f[u][1] += f[v][0], f[u][0] += max(f[v][0], f[v][1])。
答案为max(f[rt][0], f[rt][1])。

#include<bits/stdc++.h>

using namespace std;

const int N = 6010;

template <typename T>
T read(){
	T n(0), f(1);
	char ch = getchar();
	for(; !isdigit(ch); ch=getchar()) if(ch == '-') f = -1;
	for(; isdigit(ch); ch=getchar()) n = n*10 + ch-48;
	return n*f;
}

int n;
int f[N][2], a[N], fa[N], ind[N], vis[N];

void dfs(int u){
	vis[u] = 1;
	for(int i = 1; i <= n; i++){
		if(!vis[i] && fa[i] == u){
			dfs(i);
			f[u][1] += f[i][0];
			f[u][0] += max(f[i][1], f[i][0]);
		}
	}
}

int main(){
	n = read<int>();
	for(int i = 1; i <= n; i++) a[i] = read<int>();
	for(int i = 1; i < n; i++){
		int x, y;
		x = read<int>(), y = read<int>();
		fa[x] = y;
		ind[x]++;
	}

	for(int i = 1; i <= n; i++) f[i][1] = a[i];
	for(int i = 1; i <= n; i++){
		if(!ind[i]){
			dfs(i);
			printf("%d\n", max(f[i][0], f[i][1]));
			break;
		}
	}

	return 0;
}

hdu1520(同上,数据加强版)
分析:将dfs内循环1~n改为只找当前节点的子树。

#include<bits/stdc++.h>

using namespace std;

#define pb push_back

const int N = 6010;

template <typename T>
T read(){
	T n(0), f(1);
	char ch = getchar();
	for(; !isdigit(ch); ch=getchar()) if(ch == '-') f = -1;
	for(; isdigit(ch); ch=getchar()) n = n*10 + ch-48;
	return n*f;
}

int n;
int f[N][2], a[N], fa[N], ind[N], vis[N];
vector<int> s[N];

void dfs(int u){
	vis[u] = 1;
	for(int j = 0; j < s[u].size(); j++){
		int i = s[u][j];
		if(!vis[i]){
			dfs(i);
			f[u][1] += f[i][0];
			f[u][0] += max(f[i][1], f[i][0]);
		}
	}
}

int main(){
	n = read<int>();
	for(int i = 1; i <= n; i++) a[i] = read<int>();
	for(int i = 1; i < n; i++){
		int x, y;
		x = read<int>(), y = read<int>();
		s[y].pb(x);
		ind[x]++;
	}

	for(int i = 1; i <= n; i++) f[i][1] = a[i];
	for(int i = 1; i <= n; i++){
		if(!ind[i]){
			dfs(i);
			printf("%d\n", max(f[i][0], f[i][1]));
			break;
		}
	}

	return 0;
}

hdu2196
题意:求树上每个点能到达的最远距离
分析:两次dfs。
f[i][0]表示走i的子树的最长距离,f[i][1]表示走i的子树的次长距离(与f[i][0]是不同的儿子),f[i][2]表示往上经过i的父亲的最长距离。
第一次从下往上,处理好f[i][0]和f[i][1]。
当f[v][0]+w[i] > f[u][0]时 :f[u][1] = f[u][0]; f[u][0] = f[v][0] + w[i];
否则当f[v][0]+w[i] > f[u][1]时 :f[u][1] = f[v][0] + w[i];
第二次从上往下更新子节点的f[v][2]。
如果u的最长边是经过v的(f[u][0] == f[v][0] + w[i]),那么通过u只能去u的次长边的树或者u往上父亲的树。//不知道题目保不保证最长边只有一条...?
否则如果v不属于u的最长边,那么通过u只能去u的最长边的树或者u的父亲的树。

#include<bits/stdc++.h>

using namespace std;

const int N = 10010;

template <typename T>
T read(){
	T n(0), f(1);
	char ch = getchar();
	for(; !isdigit(ch); ch=getchar()) if(ch == '-') f = -1;
	for(; isdigit(ch); ch=getchar()) n = n*10 + ch-48;
	return n*f;
}

int n, e;
int f[N][3], begin[N];
struct node{
	int to, nxt, w;
}E[N<<1];

void add(int x, int y, int z){
	E[++e].to = y; E[e].w = z; E[e].nxt = begin[x];
	begin[x] = e;
}

void dfs(int u, int fa = 0){
	for(int i = begin[u]; i; i = E[i].nxt){
		int v = E[i].to;
		if(v == fa) continue;
		dfs(v, u);	
		if(f[v][0] + E[i].w > f[u][0]){
			f[u][1] = f[u][0];
			f[u][0] = f[v][0] + E[i].w;
		}
		else if(f[v][0] + E[i].w > f[u][1]) f[u][1] = f[v][0] + E[i].w;
	} 
}

void dfs_(int u, int fa = 0){
	for(int i = begin[u]; i; i = E[i].nxt){
		int v = E[i].to;
		if(v == fa) continue;
		if(f[v][0] + E[i].w == f[u][0]){
			f[v][2] = max(f[u][2], f[u][1])+E[i].w;
		}
		else f[v][2] = max(f[u][2], f[u][0])+E[i].w;
		dfs_(v, u);
	}
}

int main(){
	n = read<int>();
	for(int i = 2; i <= n; i++){
		int x, y;
		x = read<int>(); y = read<int>();
		add(x, i, y);
	}

	dfs(1);
	dfs_(1);
	for(int i = 1; i <= n; i++) printf("%d ", max(f[i][0], f[i][2]));

	return 0;
}

poj3107
题意:求树的重心
分析:(感觉不算是树形DP了,直接dfs回溯的时候计数)
son[i]表示i的子树个数,f[i]表示除去i节点后剩余分支子树最大值。
答案即为min(f[i]), 按照字典序输出。
假设v是u的儿子,
f[u] = max(n-son[u], max(son[v]));
//n-son[u]表示u以上的节点个数,我们将u视作根节点时,u以上即为u的一棵子树。然后与u的其他子树中的(节点树)最大子树选max

#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>

using namespace std;

const int oo = 0x7f7f7f7f;
const int N = 50010;

template <typename T>
T read(){
	T n(0), f(1);
	char ch = getchar();
	for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = -1;
	for(; isdigit(ch); ch = getchar()) n = n*10 + ch-48;
	return n*f;
}

int n, e, mn;
int to[N<<1], nxt[N<<1];
int son[N], begin[N], smx[N], f[N];

void add(int x, int y){
	to[++e] = y; nxt[e] = begin[x]; begin[x] = e;
	to[++e] = x; nxt[e] = begin[y]; begin[y] = e;
}

void chkmn(int& a, int b){
	a = a < b ? a : b;
}

void dfs(int u, int fa = 0){
	for(int i = begin[u]; i; i = nxt[i]){
		int v = to[i];
		if(v == fa) continue;
		dfs(v, u);
		son[u] += son[v];
		if(son[v] > smx[u]) smx[u] = son[v];
	}
	f[u] = max(n-son[u], smx[u]);
}

int main(){
	n = read<int>();
	for(int i = 1; i < n; ++i){
		int x, y;
		x = read<int>();
		y = read<int>();
		add(x, y);
	}

	for(int i = 1; i <= n; ++i) son[i] = 1, smx[i] = 0;
	dfs(1);
	mn = oo;
	for(int i = 1; i <= n; ++i) chkmn(mn, f[i]);
	for(int i = 1; i <= n; ++i) if(f[i] == mn) printf("%d ", i);
	puts("");
	return 0;
}

poj3140
题意:删除某边后剩余两个分支差值最小是多少
分析:把上面的程序瞎**改一下,不过这题每个节点都有一个值,所以总数不是n而是sum(总值)。
//我真的都没思考,随便改一改试一试WA了x次之后竟然对了...

#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>

using namespace std;

typedef long long ll;
const ll oo = 1e18+7;
const int N = 100010;

template <typename T>
T read(){
	T n(0), f(1);
	char ch = getchar();
	for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = -1;
	for(; isdigit(ch); ch = getchar()) n = n*10 + ch-48;
	return n*f;
}

int n, m, e;
ll mn, sum;
ll to[N<<1], nxt[N<<1];
ll son[N], begin[N], f[N];

void init(){
	sum = 0;
	for(int i = 1; i <= m<<1; ++i) son[i] = nxt[i] = 0;
	for(int i = 1; i <= n; ++i) begin[i] = f[i] = 0;
}

void add(int x, int y){
	to[++e] = y; nxt[e] = begin[x]; begin[x] = e;
	to[++e] = x; nxt[e] = begin[y]; begin[y] = e;
}

void chkmn(ll& a, ll b){
	a = a < b ? a : b;
}

void dfs(int u, int fa = 0){
	for(int i = begin[u]; i; i = nxt[i]){
		int v = to[i];
		if(v == fa) continue;
		dfs(v, u);
		son[u] += son[v];
	}
	f[u] = sum-son[u] > son[u] ? sum-son[u]*2 : 2*son[u]-sum;
}

int main(){
	int tot = 0;
	while(scanf("%d%d", &n, &m) && n && m){
		if(!n && !m) break;

		init();
		for(int i = 1; i <= n; i++) son[i] = read<int>(), sum += son[i];
		for(int i = 1; i <= m; ++i){
			int x, y;
			x = read<int>();
			y = read<int>();
			add(x, y);
		}
		
		printf("Case %d: ", ++tot);
		//for(int i = 1; i <= n; ++i) son[i] = 1;
		dfs(1);
		mn = oo;
		for(int i = 1; i <= n; ++i){
			//printf("%d %lld %lld %lld\n", i, f[i], n-son[i], son[i]);
			chkmn(mn, f[i]);
		}
		printf("%lld\n", mn);
	}
	return 0;
}
posted @ 2017-10-25 15:47  Hanser  阅读(180)  评论(0编辑  收藏  举报