ABC298Ex(2)
多次询问 \(L,R\),求 \(\sum\limits_{i}\min(d(i,L),d(i,R))\)。
不失一般性的令 \(dep_L\ge dep_R\)。
考虑 \(i\) 到 \(L/R\) 的路径是怎样的。一定是 \(i\) 到 \(L\rightarrow\) 上的某一点 \(x\) 再到 \(L/R\)。
如果按照每个点到达 \(L/R\) 对其进行染色,则每种颜色都只有一个联通块,换句话说,颜色是树的一个划分。
考虑 \(L\rightarrow R\) 路径上的染色情况,一定有一个阈值 \(m\),满足路径上到 \(L\) 距离不超过 \(m\) 的点到达 \(L\),其余到达 \(R\)。利用倍增求出 \(L\) 的树上 \(m\) 级祖先 \(x\),设 \(\operatorname{lca}(L,R)=k\),则 \(x\) 必然位于 \(L\rightarrow k\) 的路径上(保证了 \(dep_L\ge dep_R\))。不难发现所有位于 \(x\) 子树内的点都会走到 \(L\),其余都会走到 \(R\)。
但是如果 \(x=k\),由于 \(R\) 也在 \(x\) 的子树中,因此会出错。但是此时当且仅当 \(d(L,x)=d(x,R)\),我们选择 \(L\rightarrow R\) 路径上 \(x\) 的前一个点,即 \(L\) 的 \(m-1\) 级祖先即可。
因此答案就是 \(\sum\limits_{i\in\text{subtree}_x}d(i,L)+\sum\limits_{i\notin\text{subtree}_x}d(i,R)\)。
\(i\notin\text{subtree}_x\) 并不好处理,将答案改写为 \(\sum\limits_{i\in\text{subtree}_x}d(i,L)+\sum\limits_{i}d(i,R)-\sum\limits_{i\in\text{subtree}_x}d(i,R)\)。
此时式子已经相较于一开始变得更好处理,只需要求出一棵子树的所有点到子树内某点的距离和即可。即 \(\sum\limits_{i\in\text{subtree}_x}d(1,i)+d(1,p)-2\times d(1,\operatorname{lca}(i,p))\)。
\(\sum\limits_{i\in\text{subtree}_x}d(1,i)+d(1,p)\) 是好求的,考虑怎么求 \(\sum\limits_{i\in\text{subtree}_x}d(1,\operatorname{lca}(i,p))\)。
对 \(\operatorname{lca}(i,p)\) 拆贡献。设 \(p\rightarrow x\) 经过的点为 \(\{\alpha_z\}\)。则为 \(\sum\limits_{1\le j\le z}d(1,\alpha_j)\times(s_{\alpha_j}-s_{\alpha_{j-1}})\)。
其中 \(s_i\) 代表 \(i\) 的子树大小,\(s_{\alpha_0}=0\)。
充分发挥人类智慧,改写为 \(s_{\alpha_z}\times d(1,\alpha_z)+\sum\limits_{1\le i<z}s_{\alpha_i}\times(d(1,\alpha_i)-d(1,\alpha_{i+1}))\)。
发现路径上的点满足 \(\alpha_i=fa_{\alpha_{i-1}}\),且 \(\alpha_z=x\)。因此得到 \(s_x\times d(1,x)+\sum\limits_{1\le i<z}s_{\alpha_i}\)。维护 \(t_i=\sum\limits_{1\rightarrow i}s_i\),则得到 \(s_x\times d(1,x)+t_p-t_x\)。
注意,如果 \(p\) 在 \(x\) 的子树外,则应该加上 \(d(p,x)\times s_x\)。
如果令 \(f(x,p)\) 代表上面所写的 \(\sum\limits_{i\in\text{subtree}_x}d(i,p)\)。则答案为 \(f(x,L)+f(1,R)-f(x,R)\)。
时间复杂度 \(O(q\log n)\),瓶颈在于倍增求树上 \(k\) 级祖先。
#include <bits/stdc++.h>
#define int long long
#define debug(...) fprintf(stderr,##__VA_ARGS__)
bool Mbe;
const int inf=1e18;
const int maxn=2e5+10;
std::vector<int>a[maxn];
int t[maxn],dep[maxn],s[maxn],f[maxn][20],lg[maxn],n,q,h[maxn];
template<typename T,typename I>
void chkmin(T &a,I b){
a=std::min(a,b);
}
template<typename T,typename I>
void chkmax(T &a,I b){
a=std::max(a,b);
}
namespace ST{
void dfs(int p,int fa){
f[p][0]=fa,dep[p]=dep[fa]+1;
s[p]=1;
h[p]=dep[p];
for(int i:a[p]){
if(i==fa) continue;
dfs(i,p);
s[p]+=s[i];
h[p]+=h[i];
}
// t[p]=t[fa]+s[p];
}
void dfs2(int p,int fa){
t[p]=t[fa]+s[p];
for(int i:a[p]){
if(i==fa) continue;
dfs2(i,p);
}
}
void init(int rt){
dfs(rt,rt);
dfs2(rt,rt);
lg[0]=-1;
for(int i=1;i<=n;i++) lg[i]=lg[i>>1]+1;
for(int j=1;j<20;j++)
for(int i=1;i<=n;i++) f[i][j]=f[f[i][j-1]][j-1];
}
int kth(int p,int k){
for(int i=19;i>=0;i--)
if(k>=(1ll<<i)) k-=(1ll<<i),p=f[p][i];
return p;
}
int lca(int x,int y){
if(dep[x]<dep[y]) std::swap(x,y);
for(int i=19;i>=0;i--)
if(dep[f[x][i]]>=dep[y]) x=f[x][i];
if(x==y) return x;
for(int i=19;i>=0;i--)
if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
int dis(int x,int y){
return dep[x]+dep[y]-dep[lca(x,y)]*2;
}
int query(int x,int p){//求出 x 子树中的所有点到 p 的距离和
int res=s[x]*(dep[x]-1)+t[p]-t[x];
// debug("qwq res=%lld x=%lld p=%lld\n",res,x,p);
if(lca(x,p)!=x) res=(dep[x]-1)*s[x];
res=h[x]-s[x]+s[x]*dis(1,(lca(x,p)==x?p:x))-2*res;
if(lca(x,p)!=x) res+=s[x]*dis(p,x);
// debug("res=%lld x=%lld h%lld s=%lld lca=%lld\n",res,x,h[x],s[x],lca(x,p));
return res;
}
}
void check_s_t(){
debug("checking\n");
for(int i=1;i<=n;i++) debug("i=%lld s=%lld t=%lld h=%lld\n",i,s[i],t[i],h[i]);
}
bool Men;
signed main(){
debug("%.8f\n",((&Men-&Mbe)/1048576.0));
std::cin>>n;
for(int i=1;i<n;i++){
int u,v;
std::cin>>u>>v;
a[u].push_back(v),a[v].push_back(u);
}
ST::init(1);
// debug("check:%lld\n",ST::query(4,1));
// check_s_t();
std::cin>>q;
while(q--){
int x,y;
std::cin>>x>>y;
if(dep[x]<dep[y]) std::swap(x,y);
int k=ST::lca(x,y);//x y 的 lca
// debug("k=%lld\n",k);
int m=ST::dis(x,y)/2;//阈值
int _x=ST::kth(x,m);//树上阈值级祖先
// debug("lca=%lld 阈值=%lld kth=%lld Q1=%lld Q2=%lld Q3=%lld 大家觉得呢:%lld\n",k,m,_x,ST::query(_x,x),ST::query(1,y),ST::query(_x,y),114);
int ans=inf;
if(_x==k) _x=ST::kth(x,m-1);
chkmin(ans,ST::query(_x,x)+ST::query(1,y)-ST::query(_x,y));
// chkmin(ans,ST::query(_x,x)+ST::query(1,y)-2*ST::query(_x,y));
// if(ST::dis(x,y)&1) _x=f[_x][0],chkmin(ans,ST::query(_x,x)+ST::query(1,y)-2*ST::query(_x,y));
std::cout<<ans<<"\n";
}
// debug("%.11lfms %.8f\n",clock(),(clock()/CLOCKS_PER_SEC)*1e3);
}
/*
North London forever
*/