题目让求得是从任意一点出发可以不回来得到的最大的价值

这应该不算特别水的树形dp了,它不止要从上往下dfs,后来海要重新dfs,根据父亲节点更新儿子节点,算是正常的树形dp中比较简单的吧。

思路:

  先从上往下dp,求出从该节点往下来在回到该节点的最大价值,不用回到该节点的最大价值以及此时停在哪一颗子树上,不用回到该节点且不停在前面的子树上的最大价值(只是不用回到该节点,不是一定不能回到该节点)

      然后重新dfs,计算出儿子节点往上能回来的最大价值以及不用回来的最大价值,显然结果就是max(往下再回来的最大价值+往上不用回来的最大价值,往下不用回来的最大价值+往上再回来的最大价值)。

代码:

#include<bits/stdc++.h>
typedef long long ll;
using namespace std;
int val[100010];
int ans[100010];
int val1[100010];//回到自身
int val2[100010];//没有回到自身
int val3[100010];//次优
int id[100010];//最后从哪个枝上面走了之后没有回到自身
vector<pair<int,int> >v[100010];
void dfs1(int s,int fa)
{
    val1[s]=val2[s]=val3[s]=val[s];
    for(int i=0;i<v[s].size();i++)
    {
        int t=v[s][i].first;
        int c=v[s][i].second;
        if(t==fa)
            continue;
        dfs1(t,s);
        int temp=max(val1[t]-2*c,0);
        val2[s]+=temp;
        val3[s]+=temp;
        if(val1[s]+val2[t]-c>val2[s])
        {
            val3[s]=val2[s];
            val2[s]=val1[s]+val2[t]-c;
            id[s]=t;
        }
        else if(val1[s]+val2[t]-c>val3[s])
            val3[s]=val1[s]+val2[t]-c;
        val1[s]+=temp;
    }
}
void dfs2(int s,int fa,int temp3,int temp4)
{//temp3表示向上走还要回来能得到的优势,temp4对应的是不回来的
    ans[s]=max(val1[s]+temp4,val2[s]+temp3);
    val2[s]+=temp3;
    val3[s]+=temp3;
    if(val2[s]<=val1[s]+temp4)//更新向上走了之后对应的结果
    {
        val2[s]=val1[s]+temp4;//这地方不更新val3[s]是因为一定用不到val3[s]了
        id[s]=fa;
    }
    else if(val3[s]<=val1[s]+temp4)
        val3[s]=val1[s]+temp4;
    val1[s]+=temp3;
    for(int i=0;i<v[s].size();i++)
    {
        int t=v[s][i].first;
        int c=v[s][i].second;
        if(t==fa)
            continue;
        int temp1=max(0,val1[s]-2*c-max(0,val1[t]-2*c));
        int temp2;
        if(id[s]==t)
            temp2=max(0,val3[s]-c-max(0,val1[t]-2*c));
        else temp2=max(0,val2[s]-c-max(0,val1[t]-2*c));
        dfs2(t,s,temp1,temp2);
    }
}
int main()
{
    int T;
    scanf("%d",&T);
    for(int cas=1;cas<=T;cas++)
    {
        int n;
        scanf("%d",&n);
        for(int i=1;i<=n;i++)
        {
            scanf("%d",&val[i]);
            v[i].clear();
        }
        int a,b,c;
        for(int i=1;i<n;i++)
        {
            scanf("%d%d%d",&a,&b,&c);
            v[a].push_back(make_pair(b,c));
            v[b].push_back(make_pair(a,c));
        }
        memset(id,-1,sizeof(id));
        dfs1(1,-1);
        dfs2(1,-1,0,0);
        printf("Case #%d:\n",cas);
        for(int i=1;i<=n;i++)
            printf("%d\n",ans[i]);
    }
    return 0;
}