虚树总结
之前学了一些算法,没有写算法总结,未来会陆续补一些。
前置知识:树形 \(dp\),\(lca\),\(dfs\) 序。
我们考虑 \([HEOI2014]\) 大工程 这道题。
显而易见,假如这道题只有一次询问,我们可以直接树形 \(dp\),快速求出答案,时间复杂度 \(O(n)\)。
但是,梦想是梦想,现实是现实,这题多组询问,假如一遍一遍求,时间复杂度 \(O(nm)\)。
同时,由于改变的是选择的点,所以你就可以提前放弃 \(up\ and\ down\) 了……
但是这道题有一个很有意思的性质:\(\sum k\le 2n\)。
这也就意味着,假如树的点数只有 \(k\) 个,\(O(\sum k)\) 的时间复杂度是可以 \(AC\) 的,这就很好。
那么,接下来的问题就是:合理转化这棵树,使得,我们可以单次用 \(O(k)\) 或 \(O(k\log k)\) 的时间复杂度建立一棵点数级别为 \(O(k)\) 的虚树(实际上还可以用单调栈建立虚树,但这里讲我比较拿手的 \(lca\) 建立虚树)。
我们考虑一下,一棵虚树中需要维护哪些点。
-
显而易见,题目中类似于“选中点”这种的关键点一定要选。
-
为了方便,通常我们也会选择树的根节点(便于维护答案,不选也行)。
-
其他的点都只是路过,但是两个关键点的共同祖先(即 \(lca\))是一定要选的,因为要通过他转移答案。
那这样很好想到暴力 \(O(k^2)\) 枚举 \(lca\),当然这样时间复杂度绝对会爆炸。
我们考虑下面这种方法:
-
对关键点序列 \(a\) 按照 \(dfs\) 序排序。
-
将相邻两个关键点的 \(lca\) 插入序列。
-
再对序列排一次序。
-
\(a_i\) 与 \(lca(a_{i-1},a_i)\) 连边。
时间复杂度 \(O(k\log k)\)。
对于本题,设 \(dp_{u,0/1/2}\) 表示第 \(u\) 个节点的子树内,所有选中节点到它的距离之和/选中节点中到它的最短距离/选中节点中到它的最长距离,\(as_{u,0/1/2}\) 则代表对于这个子树,题目所问问题的三个答案,\(i1,i2\) 分别为使 \(dp_{u,1/2}\) 取极值的 \(v\)。
则 \(dp\) 方程为:
时间复杂度 \(O(\sum k\log k)\)。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=1000005;
int n,l,q,dfn[N],fa[N][21],p[N],a[N];
int m,r,t,k,h[N],nxt[N*2],to[N*2];
int d[N],nex[N*2],go[N*2],dep[N],b[N*2];
ll dp[3][N],as[3][N],c[N*2],sz[N];
int cmp(int x,int y){
return dfn[x]<dfn[y];
}void ad(int x,int y){
to[++m]=y;nxt[m]=h[x];h[x]=m;
}void add(int x,int y,int z){
go[++r]=y;c[r]=z;
nex[r]=d[x];d[x]=r;
}void dfs(int x,int f){
dep[x]=dep[f]+1;
fa[x][0]=f;dfn[x]=++l;
for(int i=0;i<20;i++)
fa[x][i+1]=fa[fa[x][i]][i];
for(int i=h[x];i;i=nxt[i])
if(f!=to[i]) dfs(to[i],x);
}int lca(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
for(int i=20;~i;i--)
if(dep[x]-dep[y]>=(1<<i))
x=fa[x][i];
if(x==y) return x;
for(int i=20;~i;i--)
if(fa[x][i]!=fa[y][i])
x=fa[x][i],y=fa[y][i];
return fa[x][0];
}int dis(int x,int y){
return dep[x]+dep[y]-2*dep[lca(x,y)];
}void dp_(int x,int f){
ll i1=0,i2=0;
as[1][x]=1e18;
if(p[x]==2) sz[x]=1;
else dp[1][x]=1e18;
for(int i=d[x];i;i=nex[i]){
int y=go[i];
if(y==f) continue;
dp_(y,x);sz[x]+=sz[y];
dp[0][x]+=c[i]*sz[y]+dp[0][y];
if(dp[1][x]>dp[1][y]+c[i])
dp[1][x]=dp[1][y]+c[i],i1=y;
if(dp[2][x]<dp[2][y]+c[i])
dp[2][x]=dp[2][y]+c[i],i2=y;
}if(p[x]==2) as[2][x]=dp[2][x];
for(int i=d[x];i;i=nex[i]){
int y=go[i];if(y==f) continue;
as[0][x]+=as[0][y]+(c[i]*sz[y]+dp[0][y])*(sz[x]-sz[y]);
as[1][x]=min(as[1][y],as[1][x]);
as[2][x]=max(as[2][y],as[2][x]);
if(i1!=y)
as[1][x]=min(as[1][x],dp[1][x]+dp[1][y]+c[i]);
if(i2!=y)
as[2][x]=max(as[2][x],dp[2][x]+dp[2][y]+c[i]);
}
}int main(){
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cin>>n;
for(int i=1,x,y;i<n;i++)
cin>>x>>y,ad(x,y),ad(y,x);
dfs(1,0);
cin>>q;while(q--){
cin>>k;
for(int i=1;i<=k;i++)
cin>>a[i],b[++t]=a[i],p[a[i]]=2;
sort(a+1,a+k+1,cmp);
for(int i=1;i<k;i++){
int x=lca(a[i],a[i+1]);
if(!p[x]) p[x]=1,b[++t]=x;
}sort(b+1,b+t+1,cmp);
for(int i=1;i<t;i++){
int lc=lca(b[i],b[i+1]);
add(lc,b[i+1],dep[b[i+1]]-dep[lc]);
add(b[i+1],lc,dep[b[i+1]]-dep[lc]);
}int rt=lca(b[1],b[2]);dp_(rt,0);
cout<<as[0][rt]<<" "<<as[1][rt]<<" "<<as[2][rt]<<"\n";
for(int i=1;i<=t;i++){
p[b[i]]=sz[b[i]]=d[b[i]]=0;
dp[0][b[i]]=dp[1][b[i]]=dp[2][b[i]]=0;
as[0][b[i]]=dp[1][b[i]]=as[2][b[i]]=0;
}for(int i=1;i<=r;i++)
nex[i]=go[i]=c[i]=0;
r=t=0;
}return 0;
}