洛谷P5327/LOJ3046/UOJ470/BZOJ5518[ZJOI2019]语言(树链剖分+差分+线段树合并)
对于每个结点$u$,考虑维护它能到达的结点(不包括$u$本身)集合$S_u$,那么答案就是$\frac{\sum\limits^n_{i=1}{|S_u|}}{2}$。
然后考虑这个$S_u$怎么搞。暴力
先大力树剖一遍,那么$(u,v)$的路径就变成了不超过$log_2n$个区间,如果把每一个结点的$S$都用线段树维护的话,就可以通过把$(u,v)$路径上所有点跑$log_2n$个区间覆盖就可以做到$O(n^2log^2n)$说白了就是个暴力
等等,对一条路径上的所有点做相同的修改?
树上差分+线段树合并!
于是就得到了一个$O(nlog^2n)$的方法,足以通过本题。
码的时候注意一些细节:
1. 不包含$u$的集合比较难维护,我们先钦定包含$u$,统计的时候再减掉。
2. 因为要差分,所以区间覆盖要改成区间加,不过由于合并完成后的标记肯定都是正的,所以只要分有没有标记即可
3. 这个写法严格来讲是点差分,但是因为我们只关心有没有标记,边差分也是正确的,常数还小一点
4. 我们可以把每个结点要做的修改存下来,每次先合并完再修改,尽量节约空间
#include<cstdio> #include<vector> using namespace std; const int N=100050; char rB[1<<21],*S,*T; inline char gc(){return S==T&&(T=(S=rB)+fread(rB,1,1<<21,stdin),S==T)?EOF:*S++;} inline int rd(){ char c=gc(); while(c<48||c>57)c=gc(); int x=c&15; for(c=gc();c>=48&&c<=57;c=gc())x=(x<<3)+(x<<1)+(c&15); return x; } struct seg{ int x,y; seg(int x,int y):x(x),y(y){} }; struct query{ int x,y; short k; query(int x,int y,short k):x(x),y(y),k(k){} }; int G[N],to[N<<1],nxt[N<<1],cnt1=0,f[N],dep[N],sz[N],son[N],id[N],top[N],dfsc=0,rt[N],lc[N<<8],rc[N<<8],addv[N<<8],val[N<<8],cnt2=0,n; long long ans=0; vector<seg> a; vector<query> w[N]; inline void Swap(int &a,int &b){int t=a;a=b;b=t;} inline void adde(int u,int v){ to[++cnt1]=v;nxt[cnt1]=G[u];G[u]=cnt1; to[++cnt1]=u;nxt[cnt1]=G[v];G[v]=cnt1; } inline void pushup(int o,int L,int R){ val[o]=addv[o]>0?R-L+1:val[lc[o]]+val[rc[o]]; //维护权值,只关心有没有标记 } void add(int &o,int L,int R,int x,int y,short k){ //区间加 if(!o)o=++cnt2; if(x<=L&&y>=R)addv[o]+=k; else{ int M=L+R>>1; if(x<=M)add(lc[o],L,M,x,y,k); if(y>M)add(rc[o],M+1,R,x,y,k); } pushup(o,L,R); } int merge(int u,int v,int L,int R){ //线段树合并 if(!u||!v)return u|v; addv[u]+=addv[v]; if(L<R){ int M=L+R>>1; lc[u]=merge(lc[u],lc[v],L,M); rc[u]=merge(rc[u],rc[v],M+1,R); } pushup(u,L,R); return u; } void dfs1(int u,int fa){ int i,v,maxn=0; dep[u]=dep[f[u]=fa]+1; sz[u]=1; for(i=G[u];i;i=nxt[i])if((v=to[i])!=fa){ dfs1(v,u); sz[u]+=sz[v]; if(sz[v]>maxn){ son[u]=v; maxn=sz[v]; } } } void dfs2(int u,int topf){ int i,v; id[u]=++dfsc; top[u]=topf; w[u].push_back(query(id[u],id[u],1)); //先给u打标记,钦定它必须选 if(f[u])w[f[u]].push_back(query(id[u],id[u],-1)); //这里有个坑,树上差分的单点修改记得在父结点减回去 if(!son[u])return; dfs2(son[u],topf); for(i=G[u];i;i=nxt[i])if((v=to[i])!=f[u]&&v!=son[u])dfs2(v,v); } inline int lca(int u,int v){ a.clear(); while(top[u]!=top[v]){ if(dep[top[u]]<dep[top[v]])Swap(u,v); a.push_back(seg(id[top[u]],id[u])); //把路径分离成区间 u=f[top[u]]; } if(dep[u]>dep[v])Swap(u,v); a.push_back(seg(id[u],id[v])); return u; } //树剖部分 void solve(int u){ int i,v; for(i=G[u];i;i=nxt[i])if((v=to[i])!=f[u]){ solve(v); rt[u]=merge(rt[u],rt[v],1,n); } //合并 for(i=0;i<w[u].size();++i)add(rt[u],1,n,w[u][i].x,w[u][i].y,w[u][i].k); //修改 ans+=val[rt[u]]-1; //由于我们钦定了u必须选,统计时要减掉 } int main(){ int m,i,u,v,k; n=rd();m=rd(); for(i=1;i<n;++i){ u=rd();v=rd(); adde(u,v); } dfs1(1,0); dfs2(1,1); while(m--){ u=rd();v=rd(); k=lca(u,v); for(i=0;i<a.size();++i){ w[u].push_back(query(a[i].x,a[i].y,1)); w[v].push_back(query(a[i].x,a[i].y,1)); if(f[k])w[f[k]].push_back(query(a[i].x,a[i].y,-2)); } } solve(1); printf("%lld",ans>>1); return 0; }