bzoj3572[HNOI2014]世界树
拖了三个月,终于A了
前几天模拟赛暴露了我不敢写虚树的垃圾本质
实际上我当时不写虚树的暴力做法也是能过测试数据的奈何码力不足没调出来
也许这就是蒟蒻.jpg
对于这道题,首先我们把虚树搞出来.然后考虑每个点能够在树上控制的范围,必然可以表示成树上的某一棵子树中删掉一些被它包含的子树(可能不删)之后剩下的部分.那么我们对虚树上每个点求出最近的关键点,然后对虚树上每条边处理一波就吼了.
也许你看出来了这篇题解没有好好写23333,看代码吧,不明白可以在评论问
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
const int maxn=300005;
struct edge{
int to,next,w;
}lst[maxn<<1],lst2[maxn<<1];int len=1,first[maxn],len2=1,first2[maxn];
void addedge(int a,int b){
lst[len].to=b;lst[len].next=first[a];first[a]=len++;
}
void addedge2(int a,int b,int w){//printf("%d %d %d\n",a,b,w);
lst2[len2].to=b;lst2[len2].next=first2[a];lst2[len2].w=w;first2[a]=len2++;
}
int prt[maxn],depth[maxn],p[maxn][20];
int dfn[maxn],T,sz[maxn];
void dfs(int x){
dfn[x]=++T;
p[x][0]=prt[x];sz[x]=1;
for(int j=0;p[x][j];++j)p[x][j+1]=p[p[x][j]][j];
for(int pt=first[x];pt;pt=lst[pt].next){
if(lst[pt].to==prt[x])continue;
prt[lst[pt].to]=x;
depth[lst[pt].to]=depth[x]+1;
dfs(lst[pt].to);
sz[x]=sz[x]+sz[lst[pt].to];
}
}
int lca(int u,int v){
if(depth[u]<depth[v])swap(u,v);
int dlt=depth[u]-depth[v];
for(int i=0;dlt;dlt>>=1,++i){
if(dlt&1)u=p[u][i];
}
if(u==v)return u;
for(int j=19;j>=0;--j){
if(p[u][j]!=p[v][j]){
u=p[u][j];v=p[v][j];
}
}
return p[u][0];
}
int m,tot;
int point[maxn];
int seq[maxn];
bool cmp(const int &a,const int &b){
return dfn[a]<dfn[b];
}
int clk;
int ok[maxn];//ok[i]==clk?
void build(){
static int stk[maxn],top;
top=0;
++clk;
for(int i=1;i<=m;++i)seq[i]=point[i],ok[seq[i]]=clk;
bool is1=false;
for(int i=1;i<=m;++i)if(seq[i]==1)is1=true;
sort(seq+1,seq+m+1,cmp);
int lo=1;
stk[top++]=1;
if(is1)lo=2;
for(int i=lo;i<=m;++i){
int Lca=lca(seq[i],stk[top-1]);
if(Lca==stk[top-1]){
stk[top++]=seq[i];
}else{
while(top>=2&&depth[stk[top-2]]>=depth[Lca]){
addedge2(stk[top-1],stk[top-2],depth[stk[top-1]]-depth[stk[top-2]]);
addedge2(stk[top-2],stk[top-1],depth[stk[top-1]]-depth[stk[top-2]]);
--top;
}
if(Lca!=stk[top-1]){
addedge2(stk[top-1],Lca,depth[stk[top-1]]-depth[Lca]);
addedge2(Lca,stk[top-1],depth[stk[top-1]]-depth[Lca]);
stk[top-1]=Lca;point[++tot]=Lca;
}
stk[top++]=seq[i];
}
}
for(int i=0;i+1<top;++i){
addedge2(stk[i],stk[i+1],depth[stk[i+1]]-depth[stk[i]]);
addedge2(stk[i+1],stk[i],depth[stk[i+1]]-depth[stk[i]]);
}
}
int p1[maxn],p2[maxn],d1[maxn],d2[maxn];
void upd(int x,int y,int d){
if(y==0)return;
if(d<d1[x]||(d==d1[x]&&y<p1[x])){
d2[x]=d1[x];p2[x]=p1[x];
d1[x]=d;p1[x]=y;
}else if(d<d2[x]||(d==d1[x]&&y<p2[x])){
d2[x]=d;p2[x]=y;
}
}
void getnear(int x,int p){
p1[x]=p2[x]=0;d1[x]=d2[x]=0x3f3f3f3f;
if(ok[x]==clk)p1[x]=x,d1[x]=0;
for(int pt=first2[x];pt;pt=lst2[pt].next){
if(lst2[pt].to==p)continue;
getnear(lst2[pt].to,x);
upd(x,p1[lst2[pt].to],d1[lst2[pt].to]+lst2[pt].w);
}
}
void getnear2(int x,int p,int w){
if(p){
if(p1[p]==p1[x])upd(x,p2[p],d2[p]+w);
else upd(x,p1[p],d1[p]+w);
}
for(int pt=first2[x];pt;pt=lst2[pt].next){
if(lst2[pt].to==p)continue;
getnear2(lst2[pt].to,x,lst2[pt].w);
}
}
typedef pair<int,int> pr;
int uplim[maxn];
vector<pr> forbid[maxn];
int jump(int x,int k){
for(int j=19;j>=0;--j){
if(k>=(1<<j)){
k-=(1<<j);x=p[x][j];
}
}
return x;
}
int calc(int x){
sort(forbid[x].begin(),forbid[x].end());
int l=dfn[uplim[x]],r=dfn[uplim[x]]+sz[uplim[x]]-1;
int ans=0;
for(vector<pr>::iterator pt=forbid[x].begin();pt!=forbid[x].end();++pt){
if(pt->first-1>=l)ans+=pt->first-l;
l=pt->second+1;
}
if(l<=r)ans+=(r-l+1);
return ans;
}
void work(){
scanf("%d",&m);tot=m;
for(int i=1;i<=m;++i)scanf("%d",&point[i]);
build();
getnear(1,0);
getnear2(1,0,0);
for(int i=1;i<=tot;++i)uplim[seq[i]]=1,forbid[seq[i]].clear();
for(int i=1;i<=tot;++i){
int x=point[i];
for(int pt=first2[x];pt;pt=lst2[pt].next){
if(depth[lst2[pt].to]<depth[x]){
int p=lst2[pt].to;
if(p1[p]!=p1[x]){
int Dp=d1[p],Dx=d1[x];
int P=p1[p],X=p1[x];
int g;
int dis=depth[x]-depth[p];
int l=1,r=dis-1;
while(l<=r){
int mid=(l+r)>>1;
int D1=Dp+mid,D2=Dx+(dis-mid);
if((D1>D2)||(D1==D2&&X<P)){
r=mid-1;
}else{
l=mid+1;
}
}
g=jump(x,dis-l);
if(depth[uplim[X]]<depth[g])uplim[X]=g;
forbid[P].push_back(pr(dfn[g],dfn[g]+sz[g]-1));
}
}
}
}
for(int i=1;i<=m;++i)printf("%d%c",calc(point[i]),(i==m)?'\n':' ');
len2=1;
for(int i=1;i<=tot;++i)first2[point[i]]=0;first2[1]=0;//printf("\n\n");
}
int main(){
int n;scanf("%d",&n);
for(int i=1,a,b;i<n;++i){
scanf("%d%d",&a,&b);addedge(a,b);addedge(b,a);
}
depth[1]=1;
dfs(1);
int q;scanf("%d",&q);
while(q--){
work();
}
return 0;
}