网络流练习题 染色

染色
【问题描述】
有一棵𝑛个点的有根树,节点的编号为0到𝑛 − 1,其中根节点的编号为0。
有𝑚种颜色。给一个点添上第𝑖种颜色的代价为𝑤 ( 。每个点必须被添上恰好
两种不同的颜色。
定义一个点的控制集合为这个点的所有儿子与这个点自己构成的集合。注意
如果𝑏是𝑎的儿子且𝑐是𝑏的儿子,𝑐不是𝑎的儿子。
我们称一个集合是多样的,当且仅当集合中的每个点的颜色都不同。这里一
个点的颜色可以是其被添上的两种颜色的任意一种。
比如说,假设集合里有两个点,这两个点被添上的颜色都是(1,2),那么集
合是多样的,因为两个点的颜色可以是1和2。但是如果有集合里有三个点,这三
个点被添上的颜色都是(1,2),那么这个集合就不是多样的了。
称一个染色方案合法,当且仅当每个点的控制集合都是多样的。求一个染色
代价最小的合法方案。如果不存在合法方案,则输出−1。
【输入格式】
输入文件的第一行包含两个整数𝑛和𝑚。
接下来一行包含𝑛 − 1个整数,分别代表1到𝑛 − 1号点的父亲。保证给定的
父亲关系构成一棵树,而且𝑖号点的父亲节点编号一定小于𝑖。
接下来一行包含𝑚个整数,第𝑖个整数为𝑤 ( 。
【输出格式】
输出一个整数,代表最小的染色代价。如果不存在合法方案,则输出−1。
【样例输入】
3 3
0 0
1 2 3
【样例输出】
10
【样例解释】
在最优方案中,0号节点和1号节点被添上的颜色是(1,2),2号节点被添上的
颜色是(1,3)。代价为1 + 2 + 1 + 2 + 1 + 3 = 10。

对于 10%的数据,𝑛,𝑚 ≤ 5。
对于 30%的数据,𝑚 ≤ 8。
另外有 20%的数据,𝑤 2 = 𝑤 3 = ⋯ = 𝑤 m 。
对于 100%的数据,1 ≤ 𝑛 ≤ 30,2 ≤ 𝑚 ≤ 30,1 ≤ 𝑤 ( ≤ 100。

分析:挺不错的一道题,就是送的分有点多了.
   30%的数据可以状压. 贪心能够骗到90分.

   正确方法是树形dp+网络流.  令f[i][j]表示以i为根的子树中,i的颜色是j的最小染色代价. 那么f[i][j] = Σf[son[i]][c],其中c互不相同并且不同于j. 这个限制不好直接处理.  

   仔细分析一波,能发现这其实是一个最小权匹配问题! 利用费用流来实现.

   S连向i以及i的所有儿子,容量为1,费用为0.  每种颜色j连向T,容量为1,费用为0.  每个点i连向颜色j,费用为f[i][j],容量为1. 在每一层跑一次费用流即可.

   选两个颜色,起限制作用的其实是表现出来的颜色. 每次只需要选费用最小的颜色和当前决策的颜色就能构成一个合法的二元组.

#include <vector>
#include <queue>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

const int maxn = 1010,inf = 0x7fffffff;
bool flag = true;
int n,m,fa[maxn],f[maxn][maxn],a[maxn],ans = inf,S,T,anss,vis[maxn],vis2[maxn];
int head[maxn],to[maxn],nextt[maxn],w[maxn],cost[maxn],tot = 2,d[maxn];
vector <int>son[maxn];

void add(int x,int y,int z,int p)
{
    cost[tot] = p;
    w[tot] = z;
    to[tot] = y;
    nextt[tot] = head[x];
    head[x] = tot++;

    cost[tot] = -p;
    w[tot] = 0;
    to[tot] = x;
    nextt[tot] = head[y];
    head[y] = tot++;
}

bool spfa()
{
    for (int i = 1; i <= T; i++)
        d[i] = inf;
    memset(vis,0,sizeof(vis));
    memset(vis2,0,sizeof(vis2));
    queue <int> q;
    d[S] = 0;
    vis[S] = 1;
    q.push(S);
    while (!q.empty())
    {
        int u = q.front();
        q.pop();
        vis[u] = 0;
        for (int i = head[u];i;i = nextt[i])
        {
            int v = to[i];
            if (w[i] && d[v] > d[u] + cost[i])
            {
                d[v] = d[u] + cost[i];
                if (!vis[v])
                {
                    vis[v] = 1;
                    q.push(v);
                }
            }
        }
    }
    return d[T] < inf;
}

int dfs(int u,int f)
{
    if (u == T)
    {
        anss += f * d[u];
        return f;
    }
    int res = 0;
    vis2[u] = 1;
    for (int i = head[u];i;i = nextt[i])
    {
        int v = to[i];
        if (!vis2[v] && d[v] == d[u] + cost[i] && w[i])
        {
            int temp = dfs(v,min(f - res,w[i]));
            w[i] -= temp;
            w[i ^ 1] += temp;
            res += temp;
            if (res == f)
                return res;
        }
    }
    return res;
}

void dinic()
{
    while (spfa())
        dfs(S,inf);
}

void dfss(int u)
{
    if (!flag)
        return;
    if (son[u].size() == 0)
    {
        for (int i = 1; i <= m; i++)
        {
            if (i == 1)
                f[u][i] = a[1] + a[2];
            else
                f[u][i] = a[1] + a[i];
        }
        return;
    }
    if (son[u].size() + 1 > m)
    {
        flag = false;
        return;
    }
    for (int i = 0; i < son[u].size(); i++)
    {
        int v = son[u][i];
        dfss(v);
    }
    for (int i = 1; i <= m; i++)
    {
        memset(head,0,sizeof(head));
        tot = 2;
        for (int j = 0; j < son[u].size(); j++)
        {
            int v = son[u][j];
            add(S,v,1,0);
        }
        for (int j = 1; j <= m; j++)
            if (j != i)
                add(j + n,T,1,0);
        for (int j = 0; j < son[u].size(); j++)
            for (int k = 1; k <= m; k++)
            {
                    int v = son[u][j];
                    add(v,k + n,1,f[v][k]);
            }
        if (i == 1)
            f[u][i] = a[1] + a[2];
        else
            f[u][i] = a[1] + a[i];
        anss = 0;
        dinic();
        f[u][i] += anss;
    }
}

int main()
{
    scanf("%d%d",&n,&m);
    S = n + m + 1;
    T = S + 1;
    for (int i = 2; i <= n; i++)
    {
        scanf("%d",&fa[i]);
        fa[i]++;
        son[fa[i]].push_back(i);
    }
    for (int i = 1; i <= m; i++)
        scanf("%d",&a[i]);
    sort(a + 1,a + 1 + m);
    dfss(1);
    for (int i = 1; i <= m; i++)
        ans = min(ans,f[1][i]);
    if (!flag)
        printf("-1");
    else
        printf("%d\n",ans);

    return 0;
}

 

posted @ 2018-03-22 18:56  zbtrs  阅读(379)  评论(0编辑  收藏  举报