bzoj 2599: [IOI2011]Race
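依次处理当前重心的每棵子树，用 t 数组记录：在已处理过的子树中，到重心距离（边权和）为 w 的点的最小深度（边数）t[w]。这样就可以避免把同一棵子树内的两个点重复配对计算的情况。对当前子树中的每个点 x，统计的答案就是 $\min\{d[x]+t[K-\mathrm{dis}[x]]\}$，其中 d[x] 为 x 到重心的边数，dis[x] 为 x 到重心的距离。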
// BZOJ 2599 / IOI 2011 "Race": find the minimum number of edges on a tree path
// whose edge weights sum to exactly K; print -1 if no such path exists.
//
// Centroid decomposition: for each centroid, its subtrees are scanned one at a
// time. t[w] holds the fewest edges from the centroid to any vertex at distance
// w among the subtrees scanned so far. A vertex x of the current subtree is then
// paired with t[K - dis[x]], so two vertices of the same subtree are never combined.
#include <algorithm>
#include <cstdio>
#include <vector>

const int MAXN = 200005;  // vertex ids are shifted to 1-based, so every id <= n must fit

// Reads one signed decimal integer from stdin; stops scanning at EOF.
inline int ra()
{
    int x = 0, sign = 1;
    int ch = std::getchar();  // int, not char, so EOF is distinguishable
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-') sign = -1;
        ch = std::getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        x = x * 10 + (ch - '0');
        ch = std::getchar();
    }
    return x * sign;
}

int n, K, cnt, sum, root, ans;
std::vector<int> t;  // t[w] for 0 <= w <= K; the value n means "no vertex at distance w yet"
// sz[x]: subtree size in the most recent getroot DFS.
// f[x]:  largest remaining piece if x were removed.
// dis[x] / d[x]: weighted distance / edge count from x to the current centroid.
int head[MAXN], sz[MAXN], f[MAXN], dis[MAXN], d[MAXN];
bool vis[MAXN];  // vis[x]: x has already been used as a centroid
struct edge { int to, next, v; } e[2 * MAXN];

// Appends a directed edge x -> y with weight v to x's adjacency list.
void insert(int x, int y, int v)
{
    e[++cnt].next = head[x];
    e[cnt].to = y;
    e[cnt].v = v;
    head[x] = cnt;
}

// DFS over x's component (not crossing used centroids) that fills sz[] and f[]
// and leaves in `root` the vertex with the smallest f[] (the centroid).
// `sum` must already hold the component's size; f[0] must be >= every f[x].
void getroot(int x, int fa)
{
    sz[x] = 1;
    f[x] = 0;
    for (int i = head[x]; i; i = e[i].next)
    {
        int y = e[i].to;
        if (y == fa || vis[y]) continue;
        getroot(y, x);
        sz[x] += sz[y];
        f[x] = std::max(f[x], sz[y]);
    }
    f[x] = std::max(f[x], sum - sz[x]);  // the piece "above" x in this DFS
    if (f[x] < f[root]) root = x;
}

// Pairs every vertex below x (x included) with the subtrees already recorded
// in t[], updating ans. Assigns dis[]/d[] to the children it descends into.
void cal(int x, int fa)
{
    // Edge weights are assumed non-negative (the problem's constraint; negative
    // distances would index t[] out of range anyway), so nothing below a vertex
    // farther than K can help. Stopping here also keeps dis[] <= K + max weight,
    // so it cannot overflow int on long paths.
    if (dis[x] > K) return;
    ans = std::min(ans, d[x] + t[K - dis[x]]);
    for (int i = head[x]; i; i = e[i].next)
    {
        int y = e[i].to;
        if (y == fa || vis[y]) continue;
        dis[y] = dis[x] + e[i].v;
        d[y] = d[x] + 1;
        cal(y, x);
    }
}

// flag != 0: record the vertices below x in t[] (keeping the minimum depth).
// flag == 0: reset those entries back to n.
// Prunes exactly like cal(), so it only reads dis[] values that cal() refreshed.
void add(int x, int fa, int flag)
{
    if (dis[x] > K) return;
    if (flag) t[dis[x]] = std::min(t[dis[x]], d[x]);
    else t[dis[x]] = n;
    for (int i = head[x]; i; i = e[i].next)
    {
        int y = e[i].to;
        if (y != fa && !vis[y]) add(y, x, flag);
    }
}

// Handles the component whose centroid is x (`sum` holds its size on entry):
// counts paths through x, clears t[], then recurses into each remaining piece.
void work(int x)
{
    const int total = sum;  // `sum` is overwritten by the recursive calls below
    vis[x] = 1;
    t[0] = 0;  // x itself is at distance 0 with 0 edges
    for (int i = head[x]; i; i = e[i].next)
    {
        int y = e[i].to;
        if (vis[y]) continue;
        d[y] = 1;
        dis[y] = e[i].v;
        cal(y, 0);     // query before inserting, so pairs never share a subtree
        add(y, 0, 1);
    }
    for (int i = head[x]; i; i = e[i].next)
        if (!vis[e[i].to]) add(e[i].to, 0, 0);
    for (int i = head[x]; i; i = e[i].next)
    {
        int y = e[i].to;
        if (vis[y]) continue;
        // sz[] is relative to the DFS root of the getroot() call that found x.
        // A neighbour with sz[y] >= sz[x] was x's DFS parent, and its piece is
        // everything outside x's DFS subtree.
        sum = sz[y] < sz[x] ? sz[y] : total - sz[x];
        root = 0;
        getroot(y, 0);
        work(root);
    }
}

// Input: n K, then n-1 lines "u v w" with 0-based vertex ids.
// NOTE(review): getroot/cal/add recurse, so a path-shaped tree recurses about n
// levels deep; this relies on the judge's stack being large enough.
int main()
{
    n = ra();
    K = ra();
    t.assign(std::max(K, 0) + 1, n);  // t[0] is set to 0 by work()
    for (int i = 1; i < n; i++)
    {
        int x = ra() + 1, y = ra() + 1, v = ra();
        insert(x, y, v);
        insert(y, x, v);
    }
    // No path has n edges, so n doubles as "not found"; f[0] = n lets root = 0
    // act as the initial "no candidate" in getroot().
    ans = sum = f[0] = n;
    getroot(1, 0);
    work(root);
    if (ans != n) std::printf("%d\n", ans);
    else std::printf("-1\n");
    return 0;
}