[树形DP]JZOJ 3347 树的难题

Description

 

Input

输入文件 为split.in 。
第一行 包含 一个正整数 T,表示有T组测试数据 。接下来 依次是 T组测试数 据。
每组测试数 据的第一行包含个正整数N。
第二行包含 N个 0、1、2之一 的整数,依次 表示点 1到点 N的颜色。其中0表示黑色, 1表示白色, 2表示灰色。
接下来 N-1行 ,每行为三个整数 ui、vi、ci,表示 一条权值等于 ci的边 (ui, vi)。
 

Output

输出文件为 split.out 。
输出 T行 ,每一个整数, 依次 表示 每组测试数据 的答案。
 

Sample Input

1
5
0 1 1 1 0
1 2 5
1 3 3
5 2 5
2 4 16

Sample Output

10
【样例解释】
花费 10 的代价删去 边(1, 2)和边(2, 5)。
 
 

Data Constraint

对于 10% 的数据: 1 ≤ N ≤ 10。
对于 30% 的数据: 1 ≤ N ≤ 50 0。
对于 60% 的数据: 1 ≤ N ≤ 50 000 。
对于 100% 的数据: 1 ≤ N ≤ 300 000 ,1 ≤ T ≤ 5,0 ≤ ci ≤ 10^9。

分析

显然最优解只有三种:

树中有若干白色点,无黑点

有若干黑点,无白点

有若干黑点,一白点

设成DP,转移较简单

注意DP不能用DFS转移,会爆栈

其实我不明白为什么他不出个菊花图来卡我的DP,DP复杂度约为O(n*最大儿子数)

 

#pragma GCC optimize(2)
#include <iostream>
#include <cstdio>
#include <memory.h>
#include <algorithm>
#include <queue>
using namespace std;
typedef long long ll;
const int N=3e5+10;
const ll Inf=3e15;
struct Graph {
    int v,nx;
    ll w;
}g[2*N];
ll f[N][3];
int cnt,list[N],fa[N],dep[N];
int w[N],a[N];
int t,n;

void Add(int u,int v,int w) {
    g[++cnt]=(Graph){v,list[u],w};list[u]=cnt;
    g[++cnt]=(Graph){u,list[v],w};list[v]=cnt;
}
 
void BFS(int v0) {
    queue<int> q;
    while (!q.empty()) q.pop();
    q.push(v0);dep[v0]=1;
    while (!q.empty()) {
        int u=q.front();q.pop();
        for (int i=list[u];i;i=g[i].nx)
            if (!dep[g[i].v]) {
                dep[g[i].v]=dep[u]+1;fa[g[i].v]=u;
                q.push(g[i].v);
            }
    }
}

bool CMP(int a,int b) {
    return dep[a]>dep[b];
}

int main() {
    for (scanf("%d",&t);t;t--) {
        cnt=0;memset(list,0,sizeof list);
        scanf("%d",&n);
        for (int i=1;i<=n;i++) scanf("%d",&w[i]),a[i]=i;
        for (int i=1,u,v,w;i<n;i++) scanf("%d%d%d",&u,&v,&w),Add(u,v,w);
        memset(f,0,sizeof f);memset(dep,0,sizeof dep);
        BFS(1);
        sort(a+1,a+n+1,CMP);
        for (int i=1;i<=n;i++) {
            int u=a[i];
            if (w[u]==0) {
                f[u][0]=Inf;f[u][2]=Inf;
                for (int j=list[u];j;j=g[j].nx)
                    if (g[j].v!=fa[u]) {
                        f[u][1]+=min(min(f[g[j].v][0],f[g[j].v][2])+g[j].w,f[g[j].v][1]);
                        ll t=0;
                        for (int k=list[u];k;k=g[k].nx)
                            if (k!=j&&g[k].v!=fa[u])
                                t+=min(min(f[g[k].v][0],f[g[k].v][2])+g[k].w,f[g[k].v][1]);
                        f[u][2]=min(f[u][2],f[g[j].v][2]+t);
                    }
            }
            if (w[u]==1) {
                f[u][1]=Inf;
                for (int j=list[u];j;j=g[j].nx)
                    if (g[j].v!=fa[u]) {
                        f[u][0]+=min(f[g[j].v][0],min(f[g[j].v][1],f[g[j].v][2])+g[j].w);
                        f[u][2]+=min(f[g[j].v][1],min(f[g[j].v][0],f[g[j].v][2])+g[j].w);
                    }
            }
            if (w[u]==2) {
                f[u][2]=Inf;
                for (int j=list[u];j;j=g[j].nx)
                    if (g[j].v!=fa[u]) {
                        f[u][0]+=min(f[g[j].v][0],min(f[g[j].v][1],f[g[j].v][2])+g[j].w);
                        f[u][1]+=min(f[g[j].v][1],min(f[g[j].v][0],f[g[j].v][2])+g[j].w);
                        ll t=0;
                        for (int k=list[u];k;k=g[k].nx)
                            if (k!=j&&g[k].v!=fa[u])
                                t+=min(f[g[k].v][1],min(f[g[k].v][0],f[g[k].v][2])+g[k].w);
                        f[u][2]=min(f[u][2],f[g[j].v][2]+t);
                    }
            }
        }
        printf("%lld\n",min(f[1][0],min(f[1][1],f[1][2])));
    }
}
View Code

 

posted @ 2019-07-10 21:22  Vagari  阅读(170)  评论(0编辑  收藏  举报