洛谷 P1084疫情控制题解--zhengjun
我去我TM没有去掉freopen调了一下午\(\cdots\)
思路
因为如果第\(i\)分钟可以控制住疫情,那么第\(i+1\)以及之后的都是可以的,所以,就可以二分了。
然后就是\(check\)函数如何写,有一个显而易见,就是每一个军队都要尽量靠近根节点,这样才会拦掉更多的点,所以,就要把每个军队向上提,如果可以到达根节点,就先放在一边,然后再处理,然后,因为到了根节点了,所以最优方案就是放在第二层的节点上,然后,就要看看这些时间够不够就可以了。
然后,因为向上提的过程中十分耗时间,于是,就可以用倍增的思想搞成\(log\)级别的。
代码
#include<bits/stdc++.h>
#define maxn 50001
using namespace std;
int n,m;
int a[maxn];
int head[maxn],to[maxn<<1],v[maxn<<1],nex[maxn<<1],k;
int t;
int f[maxn][20];
long long dist[maxn][20];
long long sum,l,r,mid;
void add(int x,int y,int z){
to[k]=y,v[k]=z;
nex[k]=head[x];
head[x]=k++;
}
int deep[maxn];
void bfs(){
queue<int> q;
q.push(1);
deep[1]=1;
while(!q.empty()){
int x=q.front();
q.pop();
for(int pos=head[x];pos!=-1;pos=nex[pos]){
int y=to[pos];
if(deep[y])continue;
q.push(y);
deep[y]=deep[x]+1;
f[y][0]=x;
dist[y][0]=v[pos];
for(int i=1;i<=t;i++){
f[y][i]=f[f[y][i-1]][i-1];
dist[y][i]=dist[y][i-1]+dist[f[y][i-1]][i-1];
}
}
}
}
int s[maxn];
struct zj{
long long t;
int x;
bool operator < (const zj &y)const{
return t<y.t;
}
}fre[maxn];
int dfs(int x){
if(s[x])return 1;
int ff=0;
for(int pos=head[x];pos!=-1;pos=nex[pos]){
if(deep[to[pos]]<deep[x])continue;
ff=1;
if(!dfs(to[pos]))return 0;
}
return ff;
}
int flag[maxn];
bool check(long long time){
memset(s,0,sizeof(s));
memset(flag,0,sizeof(flag));
int tot=0;
for(int i=1;i<=m;i++){
int x=a[i];
long long p=0;
for(int j=t;j>=0;j--){
if(f[x][j]>1&&p+dist[x][j]<=time){
p+=dist[x][j];
x=f[x][j];
}
}
if(f[x][0]==1&&p+dist[x][0]<=time)fre[++tot]=(zj){time-p-dist[x][0],x};
else s[x]=1;
}
for(int pos=head[1];pos!=-1;pos=nex[pos]){
if(!dfs(to[pos])){
flag[to[pos]]=1;
}
}
sort(fre+1,fre+1+tot);
long long aa[maxn],bb[maxn];
aa[0]=bb[0]=0;
for(int i=1;i<=tot;i++){
if(flag[fre[i].x]==1&&fre[i].t<dist[fre[i].x][0])flag[fre[i].x]=0;
else aa[++aa[0]]=fre[i].t;
}
for(int pos=head[1];pos!=-1;pos=nex[pos]){
if(flag[to[pos]]){
bb[++bb[0]]=v[pos];
}
}
if(aa[0]<bb[0])return 0;
sort(aa+1,aa+1+aa[0]);
sort(bb+1,bb+1+bb[0]);
for(int i=1,j=1;i<=aa[0];){
if(aa[i]>=bb[j])i++,j++;
else i++;
if(j>bb[0])return 1;
}
return 0;
}
int main(){
memset(head,-1,sizeof(head));
scanf("%d",&n);
t=log2(n);
for(int i=1;i<n;i++){
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
sum+=z;
add(x,y,z);
add(y,x,z);
}
bfs();
scanf("%d",&m);
for(int i=1;i<=m;i++)scanf("%d",&a[i]);
r=sum+1;l=-1;
while(l+1<r){
mid=(l+r)>>1;
if(check(mid))r=mid;
else l=mid;
}
if(r>sum)printf("-1");
else printf("%lld",r);
return 0;
}