好迷的树形dp。。。
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #define maxv 500500 #define maxe 1000500 using namespace std; int n,m,d,x,y,nume=0,g[maxv],dp1[maxv][23],dp2[maxv][23],w[maxv],fath[maxv]; bool mark[maxv][23]; struct edge { int v,nxt; }e[maxe]; void addedge(int u,int v) { e[++nume].v=v; e[nume].nxt=g[u]; g[u]=nume; } void dp(int x) { int flag=0; for (int i=g[x];i;i=e[i].nxt) { int v=e[i].v; if (v!=fath[x]) { fath[v]=x; dp(v);flag=1; for (int j=1;j<=d;j++) mark[x][j]|=mark[v][j-1]; } } if (!flag) { if (mark[x][0]) dp1[x][0]=dp2[x][0]=w[x]; for (int i=1;i<=d;i++) { dp2[x][i]=w[x]; dp1[x][i]=0; } return; } for (int i=1;i<=d;i++) for (int j=g[x];j;j=e[j].nxt) { int v=e[j].v; if (v!=fath[x]) dp1[x][i]+=dp1[v][i-1]; } dp2[x][d]+=w[x]; for (int i=g[x];i;i=e[i].nxt) { int v=e[i].v; if (v!=fath[x]) dp2[x][d]+=dp1[v][d]; } for (int i=d-1;i>=0;i--) { int ret=0; for (int j=g[x];j;j=e[j].nxt) { int v=e[j].v; if (v!=fath[x]) ret+=dp1[v][i]; } dp2[x][i]=dp2[x][i+1]; for (int j=g[x];j;j=e[j].nxt) { int v=e[j].v; if (v!=fath[x]) dp2[x][i]=min(dp2[x][i],dp2[v][i+1]+ret-dp1[v][i]); } } dp1[x][0]=dp2[x][0]; for (int i=1;i<=d;i++) dp1[x][i]=min(dp1[x][i],dp1[x][i-1]); for (int i=d-1;i>=0;i--) if (!mark[x][i]) dp1[x][i]=min(dp1[x][i],dp1[x][i+1]); dp2[x][0]=dp1[x][0]; } int main() { scanf("%d%d",&n,&d); for (int i=1;i<=n;i++) scanf("%d",&w[i]); scanf("%d",&m); for (int i=1;i<=m;i++) { scanf("%d",&x); mark[x][0]=true; } for (int i=1;i<=n-1;i++) { scanf("%d%d",&x,&y); addedge(x,y);addedge(y,x); } dp(1); printf("%d\n",dp1[1][0]); return 0; }