【洛谷5327】[ZJOI2019] 语言(差分+线段树合并)
大致题意: 给定一棵树,选定树上若干条链,问有多少对点在同一条链上。
前言
翻了翻过去的游记,发现去年的我遇到这题写的是\(O(n^3)\)的暴搜,然后加上了一个奇奇怪怪的剪枝(我不记得是什么了,而且现在也想不出来),就变成了\(O(n^2)\)的复杂度。
回想起来,去年的我做题真的完全都在靠搜索。不光此题,麻将那道题也是加了一个奇奇怪怪的剪枝(同样不记得究竟是啥了),估分\(20\)分的暴搜莫名其妙水到了\(50\)分。
然而最近的我写搜索题却屡屡爆炸,看来我已经失去了我做题最重要的武器和最根本的凭依,真是越来越菜了。
尽管就这题而论,现在的我已经能轻松想到\(O(n^2)\)的做法,仔细想想也能想出当初被点名卡掉的\(3\)只\(log\)的做法,但其他题目呢?不由得为之后的命运感到恐慌。
从一个点出发的答案
考虑枚举每一个点\(x\),去计算从\(x\)出发的答案。
容易发现,这个答案就是\(x\)只走某一条链能到达的点数(废话,题目里要求的就是这个),那么也就是所有经过\(x\)(这样\(x\)才能走这条链)的链的并集大小
考虑计算链的并集,一个暴力的做法就是枚举每条经过\(x\)的链,利用树上差分给这条链上的点权值加\(1\),最后遍历一遍整棵树,统计有多少权值大于\(0\)的点,就可以计算出\(x\)的答案。于是我们就得到了一个朴素的\(O(n^2)\)的做法,可以拿到\(40\)分。(实际上这个做法和正解无关,就不细讲了)
上面这个暴力实际上并不好优化,因此我们需要考虑另一种暴力。
我们把一条链\(u,v\)拆成分别从\(u,v\)到根节点的两条链,则它们的并集就相当于在原先基础上多出了从\(LCA(u,v)\)到根节点的一段。
而如果我们把经过\(x\)的所有链都拆成这样两条链,那么它们的并集就相当于在原先基础上多出了所有点的\(LCA\)到根节点的一段。
这样一转化显然就简单了许多,且所有点\(LCA\)到根的一段的非法贡献也是很容易计算并减去的。
然后我们只需考虑如何求出若干条到根节点的链的并集大小。
这就有一个套路的做法:我们把所有点按\(dfs\)序进行排序,则最后的答案就是所有点的深度减去相邻两点\(LCA\)的深度。
顺便提一下,所有点的\(LCA\)其实就是\(dfs\)序最小的点和\(dfs\)序最大的点的\(LCA\),也即所有点\(LCA\)到根的一段的非法贡献就是\(dfs\)序最小的点和\(dfs\)序最大的点的\(LCA\)的深度。
(注意,本来非法贡献应该是深度\(-1\),因为\(LCA\)是计算在答案内的,但由于当前点不应该计算在答案内,所以就在这里一并减去了)。
还有,如果此题想要有较为优秀的复杂度,需要使用\(O(1)\)的\(LCA\),这里就不予介绍了。
所有点的答案
上面讲了许多,但都是在说如何计算一个点的答案,且复杂度和暴力相比没有起到任何优化。
考虑如果对每个点都去枚举所有经过它的链再计算答案,显然复杂度会爆炸。
这里我们先在刚刚的方法上做一个小小的修改,本来我们需要按\(dfs\)序排序,现在我们把每个点插入一个以\(dfs\)序为下标的线段树中。(果然线段树这东西什么都能做)
则每个点的深度依旧是每个点的深度,而要减去相邻两点\(LCA\)的深度在线段树上应该是很好实现的:只要记下每个区间最左和最右的点,合并区间时减去左区间最右的点和右区间最左的点的\(LCA\)深度即可。
然后我们回到刚刚的话题,要求出经过每个点的链,就是对于一条链要给所有它经过的点打上标记,容易想到树上差分(即对于路径\(u,v\),给\(u,v\)加\(1\),给\(LCA(u,v),fa(LCA(u,v))\)减\(1\))。
在此题中,打标记这一操作具体就是指在线段树中插入\(u,v\)。
而树上差分有一个关键步骤就是统计子树标记和,那么刚刚搞出来的线段树就发挥作用了:只要通过线段树合并,就能很方便地统计标记。
这样一来,这道恶心的题目总算是被我们做出来了。
代码
#include<bits/stdc++.h>
#define Tp template<typename Ty>
#define Ts template<typename Ty,typename... Ar>
#define Reg register
#define RI Reg int
#define Con const
#define CI Con int&
#define I inline
#define W while
#define N 100000
#define LN 20
#define LL long long
#define add(x,y) (e[++ee].nxt=lnk[x],e[lnk[x]=ee].to=y)
#define swap(x,y) (x^=y^=x^=y)
using namespace std;
int n,ee,lnk[N+5];struct edge {int to,nxt;}e[N<<1];
class FastIO
{
private:
#define FS 100000
#define tc() (A==B&&(B=(A=FI)+fread(FI,1,FS,stdin),A==B)?EOF:*A++)
#define D isdigit(c=tc())
char c,*A,*B,FI[FS];
public:
I FastIO() {A=B=FI;}
Tp I void read(Ty& x) {x=0;W(!D);W(x=(x<<3)+(x<<1)+(c&15),D);}
#undef D
}F;
class EulerLCA//O(1)LCA
{
private:
#define Shallower(x,y) (P[x]<P[y]?(x):(y))
int D[N+5],P[N+5],Lg[2*N+5],Mn[2*N+5][LN+5];
I void dfs(CI x,CI lst=0)
{
RI i;for(Mn[D[x]=++d][0]=x,i=lnk[x];i;i=e[i].nxt) e[i].to^lst&&
(P[e[i].to]=P[f[e[i].to]=x]+1,dfs(e[i].to,x),Mn[++d][0]=x);//括号序列
}
public:
I int operator [] (CI x) {return D[x];}I int operator () (CI x) {return P[x];}
int d,f[N+5];I void Init()
{
RI i,j;for(P[1]=1,dfs(1),Lg[0]=-1,i=1;i<=d;++i) Lg[i]=Lg[i>>1]+1;
for(j=1;(1<<j)<=d;++j) for(i=1;i+(1<<j)-1<=d;++i)//预处理RMQ
Mn[i][j]=Shallower(Mn[i][j-1],Mn[i+(1<<j-1)][j-1]);
}
I int LCA(RI x,RI y)//RMQ求LCA
{
if(!x||!y) return 0;(x=D[x])>(y=D[y])&&swap(x,y);//转化为dfs序,保证左端点小于右端点
RI k=Lg[y-x+1];return Shallower(Mn[x][k],Mn[y-(1<<k)+1][k]);//RMQ
}
}E;
class SegmentTree//线段树
{
private:
#define PU(x)\
(\
O[x].G=O[O[x].S[0]].G+O[O[x].S[1]].G-E(E.LCA(O[O[x].S[0]].R,O[O[x].S[1]].L)),\
O[x].L=O[O[x].S[0]].L?O[O[x].S[0]].L:O[O[x].S[1]].L,\
O[x].R=O[O[x].S[1]].R?O[O[x].S[1]].R:O[O[x].S[0]].R\
)//上传信息,注意减去相邻点LCA的深度
int Nt;struct node {int T,L,R,G,S[2];}O[N*LN<<2];
public:
I void U(int& rt,CI p,CI v,CI l=1,CI r=E.d)//单点修改
{
if(!rt&&(rt=++Nt),l==r) return (void)//对于叶节点
((O[rt].T+=v)?(O[rt].G=E(p),O[rt].L=O[rt].R=p):(O[rt].G=O[rt].L=O[rt].R=0));//判断该点有无贡献
RI mid=l+r>>1;E[p]<=mid?U(O[rt].S[0],p,v,l,mid):U(O[rt].S[1],p,v,mid+1,r),PU(rt);
}
I void Merge(int& x,CI y,CI l=1,CI r=E.d)//线段树合并
{
if(!x||!y) return (void)(x|=y);RI mid=l+r>>1;
if(l==r) return (void)(O[x].T+=O[y].T,O[x].G|=O[y].G,O[x].L|=O[y].L,O[x].R|=O[y].R);//注意"|"
Merge(O[x].S[0],O[y].S[0],l,mid),Merge(O[x].S[1],O[y].S[1],mid+1,r),PU(x);
}
I int Q(CI rt) {return O[rt].G-E(E.LCA(O[rt].L,O[rt].R));}//减去dfs序最小和dfs序最大的点LCA的深度
}S;
int Rt[N+5];LL ans;vector<int> w[N+5];
I void Ins(CI x,CI y)//树上差分处理一条路径
{
#define pb push_back
S.U(Rt[x],x,1),S.U(Rt[x],y,1),S.U(Rt[y],x,1),S.U(Rt[y],y,1);RI z=E.LCA(x,y);//在x,y的线段树中插入
w[z].pb(x),w[z].pb(y),E.f[z]&&(w[E.f[z]].pb(x),w[E.f[z]].pb(y),0);//注意这里如果直接在线段树上修改会出现负数,故先开vector存标记
}
I void Calc(CI x=1,CI lst=0)//统计答案
{
RI i;for(i=lnk[x];i;i=e[i].nxt) e[i].to^lst&&(Calc(e[i].to,x),S.Merge(Rt[x],Rt[e[i].to]),0);//线段树合并统计子树标记
RI s=w[x].size();for(i=0;i^s;++i) S.U(Rt[x],w[x][i],-1);ans+=S.Q(Rt[x]);//处理删除操作,然后累加答案
}
int main()
{
RI Qt,i,x,y;for(F.read(n),F.read(Qt),i=1;i^n;++i) F.read(x),F.read(y),add(x,y),add(y,x);//读边建树
E.Init();W(Qt--) F.read(x),F.read(y),Ins(x,y);return Calc(),printf("%lld\n",ans>>1),0;//注意答案除以2
}