YbtOJ-交换游戏【树链剖分,线段树合并】
正题
题目大意
给出两棵树,对于第一棵树的每一条边\((x,y)\)询问有多少条在第二棵树上的边\((u,v)\)与其交换(连接的序号相同)后两棵树依旧是一棵树。
\(1\leq n\leq 2\times 10^5\)
解题思路
先只考虑一棵树的合法情况,对于第二棵树的边\((u,v)\)交换过来合法的当且仅当\((x,y)\)在\(u\rightarrow v\)路径上,同理的对于第二棵树合法当且仅当\((u,v)\)在\(x\rightarrow y\)路径上。
那么考虑限制一个条件,第二个条件用数据结构查询。
我们把所有的\((u,v)\)用树上差分挂在第一棵树的\(u\rightarrow v\)路径上,然后遇到一条\((u,v)\)我们让这棵树的这条边权值\(+1\)。
这样我们就保证了处理到边\((x,y)\)时有权值的只有在第一棵树上经过\((x,y)\)的\(u\rightarrow v\),那么至于第二个要求我们直接在第二棵树上查询\(x\rightarrow y\)路径上的权值和。这个可以用树链剖分维护。
而树上差分的合并功能就用线段树合并就好了。
时间复杂度:\(O(n\log^2n)\)
code
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define mp(x,y) make_pair(x,y)
using namespace std;
const int N=2e5+10,U=20;
int n,cnt,depG[N],f[N][U],rt[N],ans[N];
int siz[N],fa[N],dep[N],son[N],top[N],id[N];
vector<pair<int,int> > G[N];
vector<int> T[N],v[N];
pair<int,int> e[N];
struct SegTree{
int w[N<<5],ls[N<<5],rs[N<<5];
void Change(int &x,int L,int R,int pos,int val){
if(!x)x=++cnt;w[x]+=val;
if(L==R)return;int mid=(L+R)>>1;
if(pos<=mid)Change(ls[x],L,mid,pos,val);
else Change(rs[x],mid+1,R,pos,val);
}
int Ask(int x,int L,int R,int l,int r){
if(!x)return 0;
if(L==l&&R==r)return w[x];
int mid=(L+R)>>1;
if(r<=mid)return Ask(ls[x],L,mid,l,r);
if(l>mid)return Ask(rs[x],mid+1,R,l,r);
return Ask(ls[x],L,mid,l,mid)+Ask(rs[x],mid+1,R,mid+1,r);
}
int Merge(int x,int y,int L,int R){
if(!x||!y)return x|y;w[x]+=w[y];
if(L==R)return x;
int mid=(L+R)>>1;
ls[x]=Merge(ls[x],ls[y],L,mid);
rs[x]=Merge(rs[x],rs[y],mid+1,R);
return x;
}
}S;
void preset(int x,int fa){
f[x][0]=fa;depG[x]=depG[fa]+1;
for(int i=0;i<G[x].size();i++){
int y=G[x][i].first;
if(y==fa)continue;
preset(y,x);
}
return;
}
int LCA(int x,int y){
if(depG[x]<depG[y])swap(x,y);
for(int i=U-1;i>=0;i--)
if(depG[f[x][i]]>=depG[y])x=f[x][i];
if(x==y)return x;
for(int i=U-1;i>=0;i--)
if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i];
return f[x][0];
}
void dfs1(int x){
dep[x]=dep[fa[x]]+1;siz[x]=1;
for(int i=0;i<T[x].size();i++){
int y=T[x][i];
if(y==fa[x])continue;
fa[y]=x;dfs1(y);
siz[x]+=siz[y];
if(siz[y]>siz[son[x]])
son[x]=y;
}
return;
}
void dfs2(int x){
id[x]=++cnt;
if(son[x]){
top[son[x]]=top[x];
dfs2(son[x]);
}
for(int i=0;i<T[x].size();i++){
int y=T[x][i];
if(y==fa[x]||y==son[x])continue;
top[y]=y;dfs2(y);
}
return;
}
int GetAns(int x,int y,int rt){
int ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
ans+=S.Ask(rt,1,n,id[top[x]],id[x]);
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
if(x!=y)ans+=S.Ask(rt,1,n,id[x]+1,id[y]);
return ans;
}
void solve(int x,int fa,int ids){
for(int i=0;i<G[x].size();i++){
int y=G[x][i].first,id=G[x][i].second;
if(y==fa)continue;
solve(y,x,id);
rt[x]=S.Merge(rt[x],rt[y],1,n);
}
if(ids){
for(int i=0;i<v[x].size();i++){
int p=abs(v[x][i]);
int X=e[p].first,Y=e[p].second;
if(dep[X]>dep[Y])swap(X,Y);
S.Change(rt[x],1,n,id[Y],(v[x][i]>0)?1:-2);
}
ans[ids]=GetAns(fa,x,rt[x]);
}
return;
}
int main()
{
freopen("exchange.in","r",stdin);
freopen("exchange.out","w",stdout);
scanf("%d",&n);
for(int i=1;i<n;i++){
int x,y;
scanf("%d%d",&x,&y);
G[x].push_back(mp(y,i));
G[y].push_back(mp(x,i));
}
for(int i=1;i<n;i++){
int x,y;
scanf("%d%d",&x,&y);
T[x].push_back(y);
T[y].push_back(x);
e[i]=mp(x,y);
}
preset(1,0);
for(int j=1;j<U;j++)
for(int i=1;i<=n;i++)
f[i][j]=f[f[i][j-1]][j-1];
for(int i=1;i<n;i++){
int x=e[i].first,y=e[i].second,lca=LCA(x,y);
v[x].push_back(i);v[y].push_back(i);
v[lca].push_back(-i);
}
dfs1(1);top[1]=1;dfs2(1);
solve(1,0,0);
for(int i=1;i<n;i++)
printf("%d ",ans[i]);
return 0;
}