CF1336F Journey

\(\color{#FF003F}{\texttt {CF1336F Journey}}\)

对两条链的 \(\operatorname {lca}\) 是否相同进行分类讨论。下面 \(x\) 的链指 \(\operatorname {lca}(s,t)=x\) 的链,链 \((s,t)\) 需要满足 \(dfn_s<dfn_t\)

  1. 如果\(\operatorname {lca}\)不同。

dfs整颗树,并在 \(\operatorname {lca}\) 的深度较浅处产生贡献。这样做的好处是 两条链的交一定是 在深度较深的链上的一条以\(\operatorname {lca}\)为端点的链,树状树组统计。
具体地,dfs到一个点 \(x\) 时,先递归处理儿子,再统计 \(x\) 和子树内的链的贡献。
放张官方sol的图。

图中有两条链,E-F和G-H,他们的交是B-G,且B-C和B-D的距离是k。
处理G-H的时候,对C,D做子树加,在E-F统计答案的时候,查询E,F的值。

  1. 如果 \(\operatorname {lca}\) 相同。

枚举每个点作为\(\operatorname {lca}\)。当前点是 \(x\)。将 \(x\) 的所有链的 \(s\) 端点建棵虚树,\(t\) 端点存在 \(s\) 上。
\(\operatorname {lca}(x1,x2)=a\),从 \(a\)\(y1\)\(k\) 步到点 \(u\),所有合法的 \(y2\)\(u\) 的子树内。
启发式合并+线段树合并维护。

图中有两条链,B-E和C-F,他们的交是A-D,且A-X的距离是k。
\(\operatorname {lca}(B,C)=A\),A向E走 \(k\) 步到X,查询X的子树内的 \(y\) 点,点F产生贡献。

  1. 这样做还遗漏了一种情况

图中有2条链,A-E和C-D,他们的交是B-X。
其中 \(dfn_A<dfn_E<dfn_D<dfn_C\),容易发现这对链没被上面两种情况包含。

对于这种情况,把当前点的链按 \(dfn_s\) 排序,此时链的交一定是从lca向下的一条链,用类似第1种情况的方法统计即可。

复杂度 \(O(mlog^2n+nlogn)\),瓶颈在于启发式合并。

// Author -- xyr2005

#include<bits/stdc++.h>

#define lowbit(x) ((x)&(-(x)))
#define DEBUG fprintf(stderr,"Running on Line %d in Function %s\n",__LINE__,__FUNCTION__)
#define SZ(x) ((int)x.size())
#define mkpr std::make_pair
#define pb push_back

typedef long long ll;
typedef unsigned int uint;
typedef unsigned long long ull;
typedef std::pair<int,int> pi;
typedef std::pair<ll,ll> pl;
using std::min;
using std::max;

const int inf=0x3f3f3f3f,Inf=0x7fffffff;
const ll INF=0x3f3f3f3f3f3f3f3f;

std::mt19937 rnd(std::chrono::steady_clock::now().time_since_epoch().count());
template <typename _Tp>_Tp gcd(const _Tp &a,const _Tp &b){return (!b)?a:gcd(b,a%b);}
template <typename _Tp>inline _Tp abs(const _Tp &a){return a>=0?a:-a;}
template <typename _Tp>inline void chmax(_Tp &a,const _Tp &b){(a<b)&&(a=b);}
template <typename _Tp>inline void chmin(_Tp &a,const _Tp &b){(b<a)&&(a=b);}
template <typename _Tp>inline void read(_Tp &x)
{
	char ch(getchar());bool f(false);while(!isdigit(ch)) f|=ch==45,ch=getchar();
	x=ch&15,ch=getchar();while(isdigit(ch)) x=(((x<<2)+x)<<1)+(ch&15),ch=getchar();
	f&&(x=-x);
}
template <typename _Tp,typename... Args>inline void read(_Tp &t,Args &...args){read(t);read(args...);}
inline int read_str(char *s)
{
	char ch(getchar());while(ch==' '||ch=='\r'||ch=='\n') ch=getchar();
	char *tar=s;*tar=ch,ch=getchar();while(ch!=' '&&ch!='\r'&&ch!='\n'&&ch!=EOF) *(++tar)=ch,ch=getchar();
	return tar-s+1;
}

const int N=150005;
int n;
struct edge{
	int v,nxt;
}c[N<<1];
int front[N],edge_cnt;
inline void addedge(int u,int v)
{
	c[++edge_cnt]=(edge){v,front[u]};
	front[u]=edge_cnt;
}
int anc[N][21],dep[N],siz[N],dfn[N],rev[N],id;
struct seg_tr{
	struct Node{
		int ls,rs,sum;
	}f[N<<5];
	int node_cnt;
	int st[N<<5],top;
	inline void PushUp(int x){f[x].sum=f[f[x].ls].sum+f[f[x].rs].sum;}
	inline int newnode()
	{
		int cur=top?st[top--]:++node_cnt;
		f[cur]=(Node){0,0,0};
		return cur;
	}
	void Update(int &cur,int l,int r,int pos)
	{
		if(!cur) cur=newnode();
		++f[cur].sum;
		if(l==r) return;
		int mid=(l+r)>>1;
		if(pos<=mid) Update(f[cur].ls,l,mid,pos);
		else Update(f[cur].rs,mid+1,r,pos);
	}
	int Query(int L,int R,int l,int r,int cur)
	{
		if(!cur) return 0;
		if(L<=l&&r<=R) return f[cur].sum;
		int mid=(l+r)>>1;
		return (L<=mid?Query(L,R,l,mid,f[cur].ls):0)+(R>mid?Query(L,R,mid+1,r,f[cur].rs):0);
	}
	int merge(int a,int &b)
	{
		if(!a||!b) return a|b;
		f[a].sum+=f[b].sum;
		f[a].ls=merge(f[a].ls,f[b].ls);
		f[a].rs=merge(f[a].rs,f[b].rs);
		st[++top]=b,b=0;
		return a;
	}
	void del(int &x)
	{
		if(!x) return;
		del(f[x].ls),del(f[x].rs);
		st[++top]=x,x=0;
	}
}tr;
int Fa[N];
void dfs1(int x,int fa)
{
	dep[x]=dep[fa]+1,anc[x][0]=fa,Fa[x]=fa,siz[x]=1;
	for(int i=1;i<=20;++i) anc[x][i]=anc[anc[x][i-1]][i-1];
	dfn[x]=++id,rev[x]=id;
	for(int i=front[x];i;i=c[i].nxt)
	{
		int v=c[i].v;
		if(v!=fa) dfs1(v,x),siz[x]+=siz[v];
	}
}
int jump(int x,int k)
{
	for(int i=20;i>=0;--i) if((k>>i)&1)	x=anc[x][i];
	return x;
}
int lca(int x,int y)
{
	if(dep[x]<dep[y]) std::swap(x,y);
	for(int i=20;i>=0;--i) if(dep[anc[x][i]]>=dep[y]) x=anc[x][i];
	if(x==y) return x;
	for(int i=20;i>=0;--i) if(anc[x][i]!=anc[y][i]) x=anc[x][i],y=anc[y][i];
	return anc[x][0];
}
ll ans;
int k;
struct BIT{
	int c[N];
	inline void clear(){memset(c,0,sizeof(c));}
	inline void add(int x,int C){++x;for(;x<N;x+=lowbit(x))c[x]+=C;}
	inline int sum(int x){++x;int ans=0;for(;x;x-=lowbit(x))ans+=c[x];return ans;}
}_tr;
struct node{
	int x,y;
	inline bool operator < (const node &o)const{return dfn[x]<dfn[o.x];}
};
std::vector<node> v[N];
void dfs2(int x,int fa)
{
	for(int i=front[x];i;i=c[i].nxt)
	{
		int v=c[i].v;
		if(v!=fa) dfs2(v,x);
	}
	for(auto it:v[x]) ans+=_tr.sum(dfn[it.x])+_tr.sum(dfn[it.y]);
	for(auto it:v[x])
	{
		if(dep[it.x]-dep[x]>=k)
		{
			int qwq=jump(it.x,dep[it.x]-dep[x]-k);
			_tr.add(dfn[qwq],1),_tr.add(dfn[qwq]+siz[qwq],-1);
		}
		if(dep[it.y]-dep[x]>=k)
		{
			int qwq=jump(it.y,dep[it.y]-dep[x]-k);
			_tr.add(dfn[qwq],1),_tr.add(dfn[qwq]+siz[qwq],-1);
		}
	}
}
int t[N],pos,st[N],top;
std::vector<int> e[N],q[N];
int root[N];
void ins(int x)
{
	if(!top||(dfn[x]>=dfn[st[top]]&&dfn[x]<dfn[st[top]]+siz[st[top]]))
	{
		t[++pos]=x,st[++top]=x;
		return;
	}
	int l=lca(x,st[top]);
	while(top>1&&dfn[st[top-1]]>=dfn[l]) e[st[top-1]].pb(st[top]),--top;
	if(st[top]!=l) e[l].push_back(st[top]),st[top]=l,t[++pos]=l;
	st[++top]=x,t[++pos]=x;
}
int cur_node;
std::vector<int> in[N];
void dfs3(int x)
{
	std::function<void(int)> merge=[&](int a)
	{
		if(in[a].size()>in[x].size()) std::swap(in[a],in[x]),std::swap(root[a],root[x]);
		for(auto it:in[a])
		{
			int qwq;
			if(dep[x]-dep[cur_node]>=k) qwq=cur_node;
			else
			{
				int len=dep[x]+dep[it]-(dep[cur_node]<<1);
				if(len<k) continue;
				qwq=jump(it,len-k);
			}
			ans+=tr.Query(dfn[qwq],dfn[qwq]+siz[qwq]-1,1,n,root[x]);
		}
		for(auto it:in[a]) in[x].push_back(it);
		in[a].clear();
		root[x]=tr.merge(root[x],root[a]);
	};
	for(auto it:q[x])
	{
		int qwq;
		if(dep[x]-dep[cur_node]>=k) qwq=cur_node;
		else
		{
			int len=dep[x]+dep[it]-(dep[cur_node]<<1);
			if(len>=k) qwq=jump(it,len-k);
			else qwq=0;
		}
		if(qwq) ans+=tr.Query(dfn[qwq],dfn[qwq]+siz[qwq]-1,1,n,root[x]);
		tr.Update(root[x],1,n,dfn[it]);
		in[x].push_back(it);
	}
	for(auto it:e[x]) dfs3(it),merge(it);
}
void solve(int x)
{
	cur_node=x,top=0,pos=0;
	std::vector<int> nd;
	for(auto it:v[x]) nd.pb(it.x),q[it.x].pb(it.y);
	nd.erase(std::unique(nd.begin(),nd.end()),nd.end());
	for(auto it:nd) ins(it);
	while(top>1) e[st[top-1]].pb(st[top]),--top;
	int minn=inf,id=0;
	for(int i=1;i<=pos;++i) if(dfn[t[i]]<minn) minn=dfn[t[i]],id=t[i];
	if(id) dfs3(id);
	for(auto it:v[x]) q[it.x].clear();
	for(int i=1;i<=pos;++i) tr.del(root[t[i]]),in[t[i]].clear(),e[t[i]].clear();
}
void _solve(int x)
{
	std::vector<int> tmp;
	for(auto it:v[x])
	{
		ans+=_tr.sum(dfn[it.x]);
		if(dep[it.y]-dep[x]>=k)
		{
			int qwq=jump(it.y,dep[it.y]-dep[x]-k);
			_tr.add(dfn[qwq],1),_tr.add(dfn[qwq]+siz[qwq],-1);
			tmp.push_back(qwq);
		}
	}
	for(auto qwq:tmp) _tr.add(dfn[qwq],-1),_tr.add(dfn[qwq]+siz[qwq],1);
}
int main()
{
	int m;read(n,m,k);
	int x,y;
	for(int i=1;i<n;++i) read(x,y),addedge(x,y),addedge(y,x);
	dfs1(1,0);
	for(int i=1;i<=m;++i)
	{
		read(x,y);
		if(dfn[x]>dfn[y]) std::swap(x,y);
		v[lca(x,y)].pb((node){x,y});
	}
	dfs2(1,0);
	_tr.clear();
	for(int i=1;i<=n;++i) std::sort(v[i].begin(),v[i].end());
	for(int i=1;i<=n;++i) _solve(i);
	for(int i=1;i<=n;++i) solve(i);
	printf("%lld\n",ans);
	return 0;
}

posted @ 2020-04-17 13:00  xyr2005  阅读(796)  评论(1编辑  收藏  举报