[ POI 2011 ] Dynamite
\(\\\)
\(Description\)
一棵\(N\)个节点的树,树上有\(M\)个节点是关键点,选出\(K\)个特殊点,使得所有关键点到特殊点的距离中最大的最小,输出最大值最小为多少。
- \(N\in [1,3\times 10^5]\),\(M,K\in [1,N]\)
\(\\\)
\(Solution\)
神仙树形DP验证方式
- 首先最大值最小一眼二分,关键在于怎么\(check\)这个二分得到的最远距离。
- 正解树型\(DP\),假设所有子树已经处理完毕,显然可以靠贡献和需求分为四类;
- \(0\):可向上贡献,子树内已全部覆盖
- \(1\):可向上贡献,子树内有待覆盖节点
- \(2\):无可向上覆盖,子树内已全部覆盖
- \(3\):无可向上覆盖,子树内有待覆盖节点
- 有趣的是发现\(1\)和\(3\)两类子树根节点的情况可合并,即\(1\)类节点的贡献是无意义的,为什么呢?因为子树内的点无法被覆盖,证明有贡献的点从根节点再往外的贡献长度是小于待覆盖节点到子树根节点距离的。而子树外的点若能够满足覆盖子树内待覆盖节点的要求,则外部节点从当前子树根节点延申出的长度为二倍的二分答案再减掉待覆盖节点到子树根节点的距离,显然比上一个长。于是点分类改为:
- \(0\):可向上贡献,子树内已全部覆盖
- \(1\):无可向上覆盖,子树内有待覆盖节点
- \(2\):无可向上覆盖,子树内已全部覆盖
- 设\(f[i]\)表示\(i\)号节点:向上最长贡献距离\((1)/\)子树内待覆盖节点到\(i\)号节点最大距离\((2)/0(3)\)。
- 然后就考虑当前节点的状态。统计所有有贡献的节点在根节点处的最大延申距离,即\(mx_1=max(mx_1,f[i]-1)\);统计所有待覆盖节点到当前根节点处最大距离,即\(mx_2=max(mx_2,f[i]+1)\)。
- 分情况讨论当前节点状态,注意若子树内待覆盖节点距离等于二分长度,当前点应设为特殊点,同时注意特判\(2\)类节点即可。
- 还要注意在树型\(DP\)后若\(1\)号节点是待覆盖的使用节点数要\(+1\)。
\(\\\)
\(Code\)
#include<cmath>
#include<cstdio>
#include<cctype>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<algorithm>
#define N 300010
#define R register
#define gc getchar
using namespace std;
int n,m,tot,cnt,d[N],hd[N];
int f[N],tag[N];
struct edge{int to,nxt;}e[N<<1];
inline void add(int u,int v){
e[++tot].to=v; e[tot].nxt=hd[u]; hd[u]=tot;
}
inline int rd(){
int x=0; bool f=0; char c=gc();
while(!isdigit(c)){if(c=='-')f=1;c=gc();}
while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=gc();}
return f?-x:x;
}
inline void dfs(int u,int fa,int lim){
int mn=-1,mx=d[u]-1;
for(R int i=hd[u];i;i=e[i].nxt)
if(e[i].to!=fa) dfs(e[i].to,u,lim);
for(R int i=hd[u];i;i=e[i].nxt)
if(e[i].to!=fa){
if(tag[e[i].to]==0) mn=max(mn,f[e[i].to]-1);
else if(tag[e[i].to]==1) mx=max(mx,f[e[i].to]+1);
}
if(mn<mx){
if(mx==lim){++cnt; f[u]=lim;tag[u]=0;}
else{f[u]=mx;tag[u]=1;}
}
else if(mn!=-1){f[u]=mn; tag[u]=0;}
else{f[u]=0;tag[u]=2;}
}
inline bool valid(int x){
cnt=0; dfs(1,0,x);
if(tag[1]==1) ++cnt;
return cnt<=m;
}
int main(){
n=rd(); m=rd();
for(R int i=1;i<=n;++i) d[i]=rd();
for(R int i=1,u,v;i<n;++i){
u=rd(); v=rd(); add(u,v); add(v,u);
}
int l=0,r=n;
while(l<r){
int mid=((l+r)>>1);
valid(mid)?r=mid:l=mid+1;
}
printf("%d\n",l);
return 0;
}