bzoj 2599
还是点对之间的问题,果断上点分治
同样,把一条路径拆分成经过根节点的两条路径,对不经过根节点的路径递归处理
然后,我们逐个枚举根节点的子树,计算出子树中某一点到根节点的距离,然后在之前已经处理过的点中找,看有没有距离之和等于k的,如果有就取最小值(这里用桶维护即可)
然后再把这个子树内的信息扔进桶里,计算下一棵子树即可
但是注意,在递归处理之前需要把桶清空!
然后就没啥了
不合法的情况就是无法更新出答案,输出-1即可
#include <cstdio> #include <cmath> #include <cstring> #include <cstdlib> #include <iostream> #include <algorithm> #include <queue> #include <stack> using namespace std; const int inf=0x3f3f3f3f; struct Edge { int next; int to; int val; }edge[400005]; int head[200005]; int cnt=1; int n,k; int s,rt; int siz[200005]; int maxp[200005]; bool vis[200005]; int dep[200005]; int dis[200005]; int has[1000005]; int ans=0x3f3f3f3f; void init() { memset(head,-1,sizeof(head)); cnt=1; } void add(int l,int r,int w) { edge[cnt].next=head[l]; edge[cnt].to=r; edge[cnt].val=w; head[l]=cnt++; } void get_rt(int x,int fa) { siz[x]=1,maxp[x]=0; for(int i=head[x];i!=-1;i=edge[i].next) { int to=edge[i].to; if(to==fa||vis[to])continue; get_rt(to,x); siz[x]+=siz[to]; maxp[x]=max(maxp[x],siz[to]); } maxp[x]=max(maxp[x],s-siz[x]); if(maxp[x]<maxp[rt])rt=x; } void calc(int x,int fa) { if(dis[x]<=k)ans=min(ans,dep[x]+has[k-dis[x]]); for(int i=head[x];i!=-1;i=edge[i].next) { int to=edge[i].to; if(to==fa||vis[to])continue; dep[to]=dep[x]+1,dis[to]=dis[x]+edge[i].val; calc(to,x); } } void update(int x,int fa) { if(dis[x]<=k)has[dis[x]]=min(has[dis[x]],dep[x]); for(int i=head[x];i!=-1;i=edge[i].next) { int to=edge[i].to; if(to==fa||vis[to])continue; update(to,x); } } void erase(int x,int fa) { if(dis[x]<=k)has[dis[x]]=inf; for(int i=head[x];i!=-1;i=edge[i].next) { int to=edge[i].to; if(to==fa||vis[to])continue; erase(to,x); } } void solve(int x) { vis[x]=1,has[0]=0; for(int i=head[x];i!=-1;i=edge[i].next) { int to=edge[i].to; if(vis[to])continue; dep[to]=1,dis[to]=edge[i].val; calc(to,0); update(to,0); } for(int i=head[x];i!=-1;i=edge[i].next) { int to=edge[i].to; if(vis[to])continue; erase(to,0); } for(int i=head[x];i!=-1;i=edge[i].next) { int to=edge[i].to; if(vis[to])continue; s=siz[to],rt=0,maxp[rt]=inf; get_rt(to,0); solve(rt); } } int main() { scanf("%d%d",&n,&k); init(); for(int i=1;i<=k;i++)has[i]=n; for(int i=1;i<n;i++) { int x,y,z; scanf("%d%d%d",&x,&y,&z); x++,y++; add(x,y,z),add(y,x,z); } ans=maxp[rt]=s=n; get_rt(1,0); solve(rt); ans=(ans==n)?-1:ans; printf("%d\n",ans); return 0; }