[HNOI2014]世界树
题目大意
给出一棵树,然后每次询问给出若干个关键点,每个点包含于距离它最近的关键点中编号最小的那个点,求每个关键点包含几个点。
对于 \(100\%\) 的数据,\(n\le 3\times 10^5\),\(q\le 3\times 10^5\),\(\sum\limits_{i=1}^qm_i\le 3\times 10^5\)。
Solution
看到树上 \(dp\),再加上关键点和多组询问这些信息,容易想到虚树。(不会虚树见我的学习笔记)
所以我们套路地建出虚树,然后呢?应当先考虑在原树上暴力怎么做。实际上就是对于每个点求出距离它最近的那个关键点,然后那个关键点的答案加一。那在虚树上实际上就是每条边都有边权,然后带着边权跑一次上面的暴力就可以了。(如果不会求每个点距离最近的关键点,建议重学换根大炮)
这样,我们把所有在虚树内的所有点对答案的贡献算完了。接下来我们考虑不在虚树上的点。可以发现,这些点分成两种,第一种是在虚树边上被略去的点,第二种是整棵子树中都没有关键点于是被省去的点。这些点可以通过倍增等奇技淫巧算出对答案的贡献,但是码量巨大并且细节巨大多。(反正我写了一遍就挂惨了)
于是我们考虑能不能不要算那么麻烦,这些点包括在虚树内的点的贡献可不可以用一种更加优美的方式,给它算出来。\(\texttt{C}\color{red}{\texttt{3H5ClO}}\) 巨佬在他的博客中给出了另一个角度来做这个问题。我们把点从属于一个关键点的过程看成是以关键点为中心向外染色。也就是说,我们可以考虑从上到下依次加入关键点,看对答案有什么影响。
具体地,对于第一个关键点,它起初的答案是整棵树,所以答案是 siz[1]
。接下来,在预处理完虚树上每个点从属的关键点之后,我们考虑虚树上的一条边。这条边上有若干个点,如果说这条边的两端所属的关键点不同,那么这条边上必然有一个分界点,上面的点依然对上面的关键点产生贡献,下面的点对新加入的这个关键点产生贡献。于是我们可以倍增求出这个分界的位置,然后把上面的关键点的答案减去下方子树大小,再把下方子树大小加入到这个新的关键的答案中去,这样不断地做下去。
具体看张图:
本来这里所有的点都对深度小的那个点的关键点,假如说我们倍增出来从红色叉叉那里断开,那么下面的所有点的贡献要从上面的关键点中去掉,然后再加入到下面的关键点的答案中去。
实在不行就康代码。
Code
#include<bits/stdc++.h>
#define ll long long
#define inf (1<<30)
#define INF (1ll<<60)
#define pb push_back
#define pii pair<int,int>
#define mkp make_pair
#define fi first
#define se second
#define rep(i,j,k) for(int i=(j);i<=(k);i++)
#define per(i,j,k) for(int i=(j);i>=(k);i--)
#define pt(a) cerr<<#a<<'='<<a<<' '
#define pts(a) cerr<<#a<<'='<<a<<'\n'
using namespace std;
pii pmin(pii a,pii b){return a<b?a:b;}
const int MAXN=6e5+10;
vector<int> e[MAXN];
int dep[MAXN],ldfn[MAXN],rdfn[MAXN],siz[MAXN],f[20][MAXN],tot;
void pdfs(int x,int fa){
ldfn[x]=++tot;f[0][x]=fa;siz[x]=1;
rep(i,1,19) f[i][x]=f[i-1][f[i-1][x]];
for(int s:e[x]){
if(s==fa) continue;
dep[s]=dep[x]+1;
pdfs(s,x);siz[x]+=siz[s];
}rdfn[x]=tot;
}
int LCA(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
int dlt=dep[x]-dep[y];
rep(i,0,19) if((dlt>>i)&1) x=f[i][x];
if(x==y) return x;
per(i,19,0) if(f[i][x]^f[i][y]) x=f[i][x],y=f[i][y];
return f[0][x];
}
pii dp[MAXN];
int h[MAXN],stk[MAXN],isk[MAXN],ans[MAXN],rk[MAXN];
vector<int> ve[MAXN];
void dfs1(int x){
if(isk[x]) dp[x]=mkp(0,x);
else dp[x]=mkp(inf,0);
for(int s:ve[x]){
dfs1(s);
dp[x]=pmin(dp[x],mkp(dp[s].fi+dep[s]-dep[x],dp[s].se));
}
}
void dfs2(int x){
for(int s:ve[x]){
dp[s]=pmin(dp[s],mkp(dp[x].fi+dep[s]-dep[x],dp[x].se));
dfs2(s);
}
}
void dfs(int x){
for(int s:ve[x]){
int u=dp[x].se,v=dp[s].se,t=s;
if(u^v){
int d=dep[v]-(dep[u]+dep[v]-2*dep[LCA(u,v)]-(v>u))/2;
per(i,19,0) if(dep[f[i][t]]>=d) t=f[i][t];
ans[u]-=siz[t];
ans[v]+=siz[t];
}dfs(s);
}
}
void init(int x){
ans[x]=isk[x]=0;
for(int s:ve[x]) init(s);
ve[x].clear();
}
int main()
{
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
int n;cin>>n;
rep(i,2,n){
int u,v;cin>>u>>v;
e[u].pb(v);e[v].pb(u);
}pdfs(1,0);
int Q;cin>>Q;
while(Q--){
int m;cin>>m;
rep(i,1,m) cin>>h[i],isk[h[i]]=1,rk[i]=h[i];
sort(h+1,h+1+m,[&](int x,int y){return ldfn[x]<ldfn[y];});
int M=m;
rep(i,2,m) h[++M]=LCA(h[i-1],h[i]);
sort(h+1,h+1+M,[&](int x,int y){return ldfn[x]<ldfn[y];});
M=unique(h+1,h+1+M)-h-1;
int top=0;
rep(i,1,M){
while(top&&rdfn[stk[top]]<ldfn[h[i]]) top--;
if(top) ve[stk[top]].pb(h[i]);
stk[++top]=h[i];
}
dfs1(h[1]);dfs2(h[1]);
ans[dp[h[1]].se]=siz[1];
dfs(h[1]);
rep(i,1,m) cout<<ans[rk[i]]<<' ';cout<<'\n';
init(h[1]);
}
return 0;
}