[学习笔记]虚树
虚树
虚树可以应用于树形 \(DP\) 的加速。当题目规定查询点集的大小和 \(\le 10^5\) 时可以用虚树解决。
虚树的原理是在原树上重新建一棵树,使得树上只包含要询问的点和它们的 \(lca\)。
普通树形 \(DP\) 的时间复杂度为 \(O(n^2)\)。最坏形成一棵二叉树,点集大小为 \(n\),总点数为 \(nlogn\) 个,时间复杂度为 \(O(nlog(n))\)。
建树
我们先对原树进行 \(dfs\),预处理出 \(dfs\) 序,和倍增求 \(lca\) 所需的深度和父亲。
void dfs(int x,int fa){
dep[x]=dep[fa]+1;
dfn[x]=++top;
f[x][0]=fa;
for(int i=hea[x];i;i=nex[i]){
int t=to[i];
if(t!=fa){
an[t]=min(an[x],wa[i]);
dfs(t,x);
}
}
}
建树时先对询问点的 \(dfs\) 序排序,这样就能保证栈中 \(dfs\) 序单调递增。
先把 \(1\) 号节点入栈,注意要在入栈时初始化。
void ins(){
sort(h+1,h+k+1,cmp);
sta[tom=1]=1;
too[1].clear();
ans[1]=0;
v[1]=0;
for(int i=1;i<=k;i++){
if(h[i]==1) continue;
int lc=lca(h[i],sta[tom]);
if(lc!=sta[tom]){//当前元素不在栈中元素的链上。
while(dfn[lc]<dfn[sta[tom-1]]){
too[sta[tom-1]].push_back(sta[tom]);
tom--;
}//弹出多余的元素并建边,保留lca下的第一个元素。
if(dfn[lc]>dfn[sta[tom-1]]){
too[lc].clear();
ans[lc]=0;
v[lc]=0;
too[lc].push_back(sta[tom]);//加边
sta[tom]=lc;//将lca入栈。
}//lca不在栈中(如果在栈中lca就是栈中最后一个元素)。
else
too[lc].push_back(sta[tom--]);
}
//如果lca为栈顶元素,那么新加入点与栈中元素在一条链上,直接入栈。
sta[++tom]=h[i];//入栈当前元素
too[h[i]].clear();
ans[h[i]]=0;
v[h[i]]=0;
}
for(int i=1;i<tom;i++)
too[sta[i]].push_back(sta[i+1]);
ans[1]=0;//对栈中剩余元素加边。
v[1]=0;
}
剩余部分就是正常的树形 \(DP\) 啦。
例题
code
#include<iostream>
#include<algorithm>
#include<vector>
#include<cstdio>
using namespace std;
const int N=2.5e5+10;
int n,top,dfn[N],an[N],m,h[N],k,f[N][25],dep[N],sta[N],tom;
long long ans[N];
int tot,hea[N],nex[N<<1],to[N<<1],wa[N<<1];
bool v[N];
vector<int>too[N];
void add(int x,int y,int z){
to[++tot]=y;
wa[tot]=z;
nex[tot]=hea[x];
hea[x]=tot;
}
void dfs(int x,int fa){
dep[x]=dep[fa]+1;
dfn[x]=++top;
f[x][0]=fa;
for(int i=hea[x];i;i=nex[i]){
int t=to[i];
if(t!=fa){
an[t]=min(an[x],wa[i]);
dfs(t,x);
}
}
}
int lca(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
for(int i=22;i>=0;i--){
if(dep[f[x][i]]>=dep[y]){
x=f[x][i];
}
}
if(x==y) return x;
for(int i=22;i>=0;i--){
if(f[x][i]!=f[y][i]){
x=f[x][i];
y=f[y][i];
}
}
return f[x][0];
}
bool cmp(int x,int y){
return dfn[x]<dfn[y];
}
void ins(){
sort(h+1,h+k+1,cmp);
sta[tom=1]=1;
too[1].clear();
ans[1]=0;
v[1]=0;
for(int i=1;i<=k;i++){
// cout<<h[i]<<"@";
if(h[i]==1) continue;
int lc=lca(h[i],sta[tom]);
// cout<<h[i]<<" "<<sta[tom]<<" "<<lc<<"!!!!!\n";
if(lc!=sta[tom]){
while(dfn[lc]<dfn[sta[tom-1]]){
too[sta[tom-1]].push_back(sta[tom]);
// cout<<sta[tom]<<"!";
tom--;
// cout<<"!!";
}
if(dfn[lc]>dfn[sta[tom-1]]){
too[lc].clear();
ans[lc]=0;
v[lc]=0;
too[lc].push_back(sta[tom]);
// cout<<sta[tom]<<"!";
// cout<<sta[tom]<<"!!#";
sta[tom]=lc;
// cout<<"##";
}
else{
too[lc].push_back(sta[tom--]);
}
}
sta[++tom]=h[i];
too[h[i]].clear();
ans[h[i]]=0;
v[h[i]]=0;
}
// for(int i=1;i<=tom;i++) cout<<sta[i]<<"!";
for(int i=1;i<tom;i++){
too[sta[i]].push_back(sta[i+1]);
}
ans[1]=0;
v[1]=0;
}
void dp(int x){
for(int t:too[x]){
dp(t);
if(v[t]){
ans[x]=ans[x]+an[t];
}
else ans[x]+=min(ans[t],1ll*an[t]);
}
}
int main(){
scanf("%d",&n);
for(int i=1;i<n;i++){
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
add(x,y,z);
add(y,x,z);
}
an[1]=1e9;
dfs(1,0);
for(int i=1;i<=22;i++){
for(int j=1;j<=n;j++){
f[j][i]=f[f[j][i-1]][i-1];
}
}
scanf("%d",&m);
for(int i=1;i<=m;i++){
scanf("%d",&k);
for(int j=1;j<=k;j++){
scanf("%d",&h[j]);
}
ins();
for(int j=1;j<=k;j++){
v[h[j]]=1;
}
// for(int j=1;j<=n;j++){
// if(too[j].size())
// cout<<j<<" ";
// for(int t:too[j]){
// cout<<t<<" ";
// }
// if(too[j].size())
// cout<<endl;
// }
dp(1);
printf("%lld\n",ans[1]);
}
return 0;
}
code
#include<iostream>
#include<algorithm>
#include<vector>
#include<cstdio>
using namespace std;
const int N=100010;
int n,q,k,h[N],dfn[N],top,f[N][20],dep[N],sta[N],tom,ans,g[N];
bool v[N];
int tot,hea[N],nex[N<<1],to[N<<1];
vector<int>too[N];
void add(int x,int y){
to[++tot]=y;
nex[tot]=hea[x];
hea[x]=tot;
}
void dfs(int x,int fa){
f[x][0]=fa;
dep[x]=dep[fa]+1;
dfn[x]=++top;
for(int i=hea[x];i;i=nex[i]){
int t=to[i];
if(t!=fa){
dfs(t,x);
}
}
}
int lca(int x,int y){
if(dep[x]<dep[y])swap(x,y);
for(int i=17;i>=0;i--){
if(dep[f[x][i]]>=dep[y]){
x=f[x][i];
}
}
if(x==y) return x;
for(int i=17;i>=0;i--){
if(f[x][i]!=f[y][i]){
x=f[x][i];
y=f[y][i];
}
}
return f[x][0];
}
bool cmp(int x,int y){
return dfn[x]<dfn[y];
}
void ins(){
sort(h+1,h+k+1,cmp);
sta[tom=1]=1;
v[1]=0;
g[1]=0;
too[1].clear();
for(int i=1;i<=k;i++){
if(h[i]==1) continue;
int lc=lca(h[i],sta[tom]);
if(lc!=sta[tom]){
while(dfn[sta[tom-1]]>dfn[lc]){
too[sta[tom-1]].push_back(sta[tom]);
tom--;
}
if(lc!=sta[tom-1]){
v[lc]=0;
g[lc]=0;
too[lc].clear();
too[lc].push_back(sta[tom]);
sta[tom]=lc;
}
else{
too[sta[tom-1]].push_back(sta[tom]);
tom--;
}
}
sta[++tom]=h[i];
v[h[i]]=0;
g[h[i]]=0;
too[h[i]].clear();
}
for(int i=1;i<tom;i++){
too[sta[i]].push_back(sta[i+1]);
}
}
void dp(int x){
int num=0;
for(int t:too[x]){
dp(t);
// if(x==15) cout<<t<<"@$";
g[x]+=g[t];
if(g[t]!=0) num++;
if(v[x]==1&&v[t]==1&&dep[t]==dep[x]+1){
ans=-1;
}
}
if(v[x]) num++;
if(num>1&&ans!=-1){
if(v[x]==1) ans+=g[x];
else ans++;
g[x]=0;
}
if(v[x]) g[x]++;
// cout<<x<<" "<<ans[x]<<" "<<num<<" "<<g[x]<<"!!\n";
// if(x==11) cout<<v[x]<<" "<<num<<"@#%";
}
int main(){
scanf("%d",&n);
for(int i=1;i<n;i++){
int x,y;
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
dfs(1,0);
for(int i=1;i<=17;i++){
for(int j=1;j<=n;j++){
f[j][i]=f[f[j][i-1]][i-1];
}
}
scanf("%d",&q);
for(int i=1;i<=q;i++){
scanf("%d",&k);
for(int j=1;j<=k;j++){
scanf("%d",&h[j]);
}
ins();
for(int j=1;j<=k;j++){
v[h[j]]=1;
}
ans=0;
// for(int j=1;j<=n;j++){
// if(too[j].size()) cout<<j<<"@";
// for(int t:too[j]){
// cout<<t<<" ";
// }
// if(too[j].size()) cout<<endl;
// }
dp(1);
printf("%d\n",ans);
}
return 0;
}