树形dp

树形dp,意思就是在树上的dp, 看了看紫书,讲了三个大点把,一个是树的最大独立集,另外一个是树的重心,最后一个是树的最长路径。给的三个例题,下面就从例题说起

第一个:工人的请愿书 uva 12186

这个题目给定一个公司的树状结构,每个员工都有唯一的一个直属上司,老板编号为0,员工1-n,只有下一级的工人请愿书不小于T%时,这个中级员工,才会签字传递给它的直属上司,问老板收到请愿书至少需要多少各个工人签字

用dp(u)表示u给上级发信至少需要多少工人,那么可以假设u有k个节点,所以需要c = (T*k - 1) / 100 + 1个直接下属,把所有的子节点的dp值从小到大排序,前c个加起来就是答案。

树形dp就是从根节点一层一层往下找,找到最优的子结构,树形的最优子结构就是子树,我是用的记忆化搜索。代码如下:

#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
const int maxn = 1e5 + 10;
vector<int> sons[maxn];
int n, t;
int dp(int u)
{
    if (sons[u].empty())
        return 1;
    int k = sons[u].size();
    vector<int> d;
    for (int i = 0; i < k; i++)
        d.push_back(dp(sons[u][i]));
    sort(d.begin(), d.end());
    int c = (k * t - 1) / 100 + 1;
    int ans = 0;
    for (int i = 0; i < c; i++)
        ans += d[i];
    return ans;
}
void init()
{
    for (int i = 0; i <= n; i++)
        sons[i].clear();
}
int main()
{
    while (~scanf("%d%d", &n, &t) && n + t)
    {
        init();
        int t;
        for (int i = 1; i <= n; i++)
        {
            scanf("%d", &t);    
            sons[t].push_back(i);
        }
        printf("%d\n", dp(0));
    }
    return 0;
}
View Code

第二个:Hali-Bula的晚会,poj3342, uva1220

这个几乎是求树的最大独立集的模板题,就是多了一个要求,判断是否唯一

考虑一个点,只有两种情况,选它,不选它

所以用 d[u][0]表示以u为根的子树中,不选u点能得到的最大人数 f[u][0]表示方案唯一性,如果f[u][0] = 1说明唯一,0说明不唯一

d[u][1]表示以u为根,选u点能得到的最大值。

状态转移方程就是d[u][1]  = sum{d[v][0]}| v是u的子节点

d[u][0] = sum{max(d[v][0], d[v][1])},所以代码如下:

#include <cstdio>
#include <iostream>
#include <vector>
#include <algorithm>
#include <string>
#include <cstring>
#include <map>
using namespace std;
const int maxn = 240;
vector<int> sons[maxn];
map<string, int> mp;
int f[maxn][2], d[maxn][2];
int n, cnt;
void init()
{
    memset(d, -1, sizeof(d));
    for (int i = 0; i <= n; i++)
        sons[i].clear();
    mp.clear();
}
int dp(int u, int flag)
{
    f[u][flag] = 1;
    if (sons[u].empty() && flag)
        return 1;
    if (sons[u].empty() && !flag)
        return 0;
    int k = sons[u].size();
    int sum = 0;
    if (flag == 1)
    {
        sum++;
        for (int i = 0; i < k; i++)
        {
            int v = sons[u][i];
            if (d[v][0] == -1)
                d[v][0] = dp(v, 0);
            sum += d[v][0];
            f[u][1] &= f[v][0];
        }
    }
    else
    {
        for (int i = 0; i < k; i++)
        {
            int v = sons[u][i];
            if (d[v][1] == -1)
                d[v][1] = dp(v, 1);
            if (d[v][0] == -1)
                d[v][0] = dp(v, 0);
            if (d[v][0] == d[v][1])
                f[u][flag] = 0;
            if (d[v][0] > d[v][1])
                f[u][0] &= f[v][0];
            else
                f[u][0] &= f[v][1];
            sum += max(d[v][1], d[v][0]);
        }
    }
    return sum;
}

int main()
{
    while (~scanf("%d", &n) && n)
    {
        init();
        string s, tmp;
        cin >> s;
        mp[s] = 0;
        cnt = 1;
        for (int i = 1; i < n; i++)
        {
            cin >> s >> tmp;
            if (mp.count(s) == 0)
                mp[s] = cnt++;
            if (mp.count(tmp) == 0)
                mp[tmp] = cnt++;
            sons[mp[tmp]].push_back(mp[s]);
        }
        int t1 = dp(0, 0); int t2 = dp(0, 1);
        bool flag = true;
        //cout << f[0][0] << endl;
        //cout << f[0][1] << endl;
        //cout << t2 << endl;
        if (t1 == t2)
            flag = false;
        if (t1 < t2 && f[0][1] == 0)
            flag = false;
        if (t1 > t2 && f[0][0] == 0)
            flag = false;
        printf("%d %s\n", max(t1, t2), flag ? "Yes" : "No");
    }
    return 0;
}
View Code

第三个:完美的服务 poj3398, uva 1218

这个和第二题差不多,就是状态多了一个,因为有个条件是每台计算机连接的服务器恰好是一个,所以多了一个状态,我是用记忆化搜索来写的,写完之后一直wrong answer,后来对比了网上的代码发现把inf设的太大了,估计是后来加着加着越界了,所以wrong了,后来改了就好了,不过记忆化写起来比递推代码长多了。。。

记忆化代码:

#include <cstdio>
#include <iostream>
#include <vector>
using namespace std;
const int maxn = 10010;
const int inf = (1e6);
vector<int> sons[maxn];
int n;
int d[maxn][3];
void init()
{
    for (int i = 0; i <= n; i++)
        d[i][0] = d[i][1] = d[i][2] = inf;
    for (int i = 0; i <= n; i++)
        sons[i].clear();
}
int dp(int u, int pre, int f)
{
    if (sons[u].size() == 1 && sons[u][0] == pre)
        if (f == 0)
            return 1;
        else if (f == 1)
            return 0;
        else
            return inf;
    int k = sons[u].size();
    int sum = 0;
    if (f == 0)//u is server
    {
        sum++;
        for (int i = 0; i < k; i++)
        {
            int v = sons[u][i];
            if (pre == v)
                continue;
            if (d[v][0] == inf)
                d[v][0] = dp(v, u, 0);
            if (d[v][1] == inf)
                d[v][1] = dp(v, u, 1);
            sum += min(d[v][0], d[v][1]);
        }
    }
    else if (f == 1)//u is not server, his father is server
    {
        for (int i = 0; i < k; i++)
        {
            int v = sons[u][i];
            if (v == pre)
                continue;
            if (d[v][2] == inf)
                d[v][2] = dp(v, u, 2);
            sum += d[v][2];
        }
    }
    else//u is not server and his father is not server;
    {
        int ans = inf;
        for (int i = 0; i < k; i++)
        {
            int v = sons[u][i];
            if (pre == v)
                continue;
            if (d[u][1] == inf)
                d[u][1] = dp(u, pre, 1);
            if (d[v][2] == inf)
                d[v][2] = dp(v, u, 2);
            if (d[v][0] == inf)
                d[v][0] = dp(v, u, 0);
            ans = min(ans, d[u][1] - d[v][2] + d[v][0]);
        }
        sum += ans;
    }
    return sum;
}
int main()
{
    int a, b;
    while (~scanf("%d", &n) && n != -1)
    {
        if (n == 0)
            continue;
        init();
        for (int i = 1; i < n; i++)
        {
            scanf("%d %d", &a, &b);
            sons[a].push_back(b);
            sons[b].push_back(a);
        }
        int t1 = dp(1, 0, 0); int t2 = dp(1, 0, 2);
        //cout << t1 << endl;
        //cout << t2 << endl;
        printf("%d\n", min(t1, t2));
    }
    return 0;
}
View Code

递推代码:

#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
using namespace std;
const int maxn = 11000;
const int inf = 1e6;
vector<int> sons[maxn];
int n;
int d[maxn][3];
void init()
{
    for (int i = 0; i <= n + 10; i++)
    {
        d[i][0] = 1; d[i][1] = 0; d[i][2] = inf;
    }
    for (int i = 0; i <= n; i++)
        sons[i].clear();
}
void dp(int u, int pre)
{
    int k = sons[u].size();
    for (int i = 0; i < k; i++)
    {
        int v = sons[u][i];
        if (pre == v)
            continue;
        dp(v, u);
        d[u][0] += min(d[v][0], d[v][1]);
        d[u][1] += d[v][2];
        d[u][2] = min(d[u][2], d[v][0] - d[v][2]);
    }
    d[u][2] += d[u][1];
}
int main()
{
    int a, b;
    while (~scanf("%d", &n) && n != -1)
    {
        if (n == 0)
            continue;
        init();
        for (int i = 1; i < n; i++)
        {
            scanf("%d %d", &a, &b);
            sons[a].push_back(b);
            sons[b].push_back(a);
        }
        dp(1, 0);
        printf("%d\n", min(d[1][0], d[1][2]));
    }
    return 0;
}
View Code

 

posted @ 2015-08-14 20:03  Howe_Young  阅读(281)  评论(0编辑  收藏  举报