【XSY3490】线段树(广义线段树,树上莫队)

题面

线段树

题解

本题分两 Part 走。

Part 1

我们需要解决: 如何在广义线段树上快速区间定位节点。

对于有 \(n\) 个叶子节点、共 \(2n-1\) 个节点的广义线段树 \(A\),我们定义 \(maxr_l\) 表示 \(A\) 中所有以 \(l\) 为左端点的区间的右端点的最大值,\(minl_r\)\(A\) 中所有以 \(r\) 为右端点的区间的左端点的最小值。然后再定义 \(L\) 树和 \(R\) 树,\(L\) 树中点 \(l\) 对应 \(A\) 树中的点 \([l,maxr_l]\)\(l\) 的父亲为 \(maxr_l+1\)\(R\) 树中 \(r\) 对应 \(A\) 树中的点 \([minl_r,r]\)\(r\) 的父亲为 \(minl_r-1\)

区间定位 \([l,r]\) 的时候,我们在 \(L\) 树上从 \(l\) 开始往上跳直到当前点 \(u\)\(maxr_u\) 恰好大于 \(r\) 为止,\(R\) 树上从 \(r\) 开始往上跳直到当前点 \(u\)\(minl_u\) 刚好小于 \(l\) 为止,这两段所对应 \(A\) 树的节点就是我们区间定位得到的节点。

考虑这样做为什么是对的,我们称 \(A\) 树上一个节点为 “右儿子节点” 当且仅当它是其父亲的右儿子,“左儿子节点” 的定义同理。假设在 \(A\) 树上区间定位 \([L,R]\) 所得到的的节点为 \(a_1,\cdots,a_m\),那么肯定存在一个 \(k\),使得 \(a_1,\cdots,a_k\) 都是右儿子节点,\(a_{k+1},\cdots,a_m\) 都是左儿子节点。

假设现在有一个 \(A\) 树上的节点,它的左端点 \(l\) 已知,而且已知它是右儿子节点,那么我们能确定这个点吗?当然能,因为以某个 \(l\) 为左端点的右儿子节点是唯一的,就是 \([l,maxr_l]\)

再回到区间定位,我们已经知道了 \(a_1\) 的左端点一定是 \(L\),那么若 \(a_1\) 是右儿子节点,\(a_1\) 是被唯一确定的,我们先找到它,设为 \(p\)。若它超出 \([L,R]\) 范围则 \(a_1\) 不可能是区间定位中的点,之后也不可能出现右儿子节点了,直接 break。否则,发现若 \(p\) 所对应的区间在 \([L,R]\) 内,则它一定是 \([L,R]\) 区间定位中的节点(即它一定是 \(a_1\)),因为其祖先所对应的区间一定会超出 \([L,R]\)。然后也就知道了 \(a_2\) 的左端点,于是递归处理,直到当前找到的节点超出 \([L,R]\) 范围为止。这样我们就能找到 \(a_1,\cdots,a_k\)。这个过程就是在 \(L\) 树上不断跳祖先的过程。

同样地,我们也能用类似的方法找到 \(a_{k+1},\cdots,a_m\)。这个过程就是在 \(R\) 树上不断跳祖先的过程。

至此,我们就把广义线段树上一段区间定位得到的节点转化为 \(L\) 树和 \(R\) 树上的两条祖先-后代链上对应的节点。

Part 2

现在是要求 \(L/R\) 树上一条链在 \(B\) 树中对应的点,到 \(L/R\) 树上一条链在 \(B\) 树中对应的点,的两两距离和。这里以 \(L\) 树到 \(R\) 树为例。

有一个单次询问复杂度关于点数成线性的虚树做法,但显然过不了。

做法是利用链的性质:差分之后转化为 \(L\) 树上一条到根的链和 \(R\) 树上一条到根的链在 \(B\) 树中对应的点的两两距离和,然后直接树上莫队。

树上莫队的具体做法是:两个指针分别在 \(L\) 树和 \(R\) 树的欧拉序上移动。当加入/删除一个点时,我们要求出它到另一棵树上一条到根的链在 \(B\) 树中对应的点的距离和,使用点分治即可。

时间复杂度 \(O(n\sqrt m\log n)\)

#include<bits/stdc++.h>

#define fi first
#define se second
#define pii pair<int,int>
#define mk(a,b) make_pair(a,b)
#define INF 0x7fffffff
#define ll long long
#define LN 14
#define N 10010
#define M 100010

using namespace std;

inline int read()
{
	int x=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9')
	{
		if(ch=='-') f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9')
	{
		x=(x<<1)+(x<<3)+(ch^'0');
		ch=getchar();
	}
	return x*f;
}

int n,m;

namespace Btree
{
	const int V=N<<1;
	int nn;
	int cnt,head[V],to[V<<1],nxt[V<<1];
	int ns,maxn,rt,fa[V],size[V];
	bool vis[V];
	vector<int>Dis[V][2];
	void adde(int u,int v)
	{
		to[++cnt]=v;
		nxt[cnt]=head[u];
		head[u]=cnt;
	}
	void getsize(int u,int fa)
	{
		size[u]=1;
		for(int i=head[u];i;i=nxt[i])
		{
			int v=to[i];
			if(vis[v]||v==fa) continue;
			getsize(v,u);
			size[u]+=size[v];
		}
	}
	void getroot(int u,int fa)
	{
		int nmax=0;
		for(int i=head[u];i;i=nxt[i])
		{
			int v=to[i];
			if(vis[v]||v==fa) continue;
			getroot(v,u);
			nmax=max(nmax,size[v]);
		}
		nmax=max(nmax,ns-size[u]);
		if(nmax<maxn) maxn=nmax,rt=u;
	}
	void getdis(int u,int fa,int dis,bool opt)
	{
		Dis[u][opt].push_back(dis);
		for(int i=head[u];i;i=nxt[i])
		{
			int v=to[i];
			if(vis[v]||v==fa) continue;
			getdis(v,u,dis+1,opt);
		}
	}
	void solve(int u)
	{
		vis[u]=1;
		getdis(u,0,0,0);
		for(int i=head[u];i;i=nxt[i])
		{
			int v=to[i];
			if(vis[v]) continue;
			getsize(v,0);
			ns=size[v],maxn=INF,getroot(v,0);
			getdis(v,0,1,1);
			fa[rt]=u,solve(rt);
		}
	}
	void init()
	{
		nn=(n<<1)-1;
		for(int i=1;i<nn;i++)
		{
			int u=read(),v=read();
			adde(u,v),adde(v,u);
		}
		getsize(1,0);
		ns=size[1],maxn=INF,getroot(1,0);
		solve(rt);
	}
	struct Dtree
	{
		int sdis[V][2],snum[V][2];
		void insert(int u,int coef)
		{
			int now=u;
			for(int i=(int)Dis[u][0].size()-1;i>=0;i--)
			{
				sdis[now][0]+=coef*Dis[u][0][i],snum[now][0]+=coef;
				if(i) sdis[now][1]+=coef*Dis[u][1][i-1],snum[now][1]+=coef;
				now=fa[now];
			}
		}
		int query(int u)
		{
			int now=u,ans=0;
			for(int i=(int)Dis[u][0].size()-1;i>=0;i--)
			{
				ans+=sdis[now][0]+snum[now][0]*Dis[u][0][i];
				if(i) ans-=sdis[now][1]+snum[now][1]*Dis[u][1][i-1];
				now=fa[now];
			}
			return ans;
		}
	}t1,t2;
}

namespace Atree
{
	struct Tree
	{
		int rt;
		int cnt,head[N],pid[N],nxt[N<<1],to[N<<1]; 
		int fa[N][LN],d[N];
		int idx,in[N],out[N];
		pii p[N<<1];
		void adde(int u,int v,int pp)
		{
			pid[v]=pp;
			to[++cnt]=v;
			nxt[cnt]=head[u];
			head[u]=cnt;
		}
		void dfs(int u)
		{
			if(u!=rt) p[in[u]=++idx]=mk(pid[u],1);
			for(int i=1;i<=13&&fa[u][i-1]!=-1;i++)
				fa[u][i]=fa[fa[u][i-1]][i-1];
			for(int i=head[u];i;i=nxt[i])
			{
				int v=to[i];
				fa[v][0]=u,d[v]=d[u]+1;
				dfs(v);
			}
			if(u!=rt) p[out[u]=++idx]=mk(pid[u],-1);
		}
		void init()
		{
			memset(fa,-1,sizeof(fa));
			d[rt]=1,dfs(rt);
		}
	}L,R;
	int node;
	pii maxr[N],minl[N];
	void dfs(int l,int r)
	{
		int u=++node;
		if(l!=r)
		{
			int mid=read();
			dfs(l,mid),dfs(mid+1,r);
		}
		maxr[l]=mk(r,u),minl[r]=mk(l,u);
	}
	void init()
	{
		dfs(1,n);
		for(int i=1;i<=n;i++)
		{
			L.adde(maxr[i].fi+1,i,maxr[i].se);
			R.adde(minl[i].fi-1,i,minl[i].se);
		}
		L.rt=n+1,R.rt=0,L.init(),R.init();
	}
	void find(int l,int r,pii &Lp,pii &Rp)
	{
		if(maxr[l].fi>r) Lp=mk(1,0);
		else
		{
			int a=l;
			for(int i=13;i>=0;i--)
				if(L.fa[a][i]!=-1&&L.fa[a][i]!=L.rt&&maxr[L.fa[a][i]].fi<=r)
					a=L.fa[a][i];
			Lp=mk(L.in[a],L.in[l]);
		}
		if(minl[r].fi<l||(l==1&&r==n)) Rp=mk(1,0);
		else
		{
			int a=r;
			for(int i=13;i>=0;i--)
				if(R.fa[a][i]!=-1&&R.fa[a][i]!=R.rt&&minl[R.fa[a][i]].fi>=l)
					a=R.fa[a][i];
			Rp=mk(R.in[a],R.in[r]);
		}
	}
}

namespace Solve
{
	using Btree::Dtree;using Btree::t1;using Btree::t2;
	ll sum,ans[M];
	int len,block[N<<1];
	struct data
	{
		int p1,p2,id,coef;
		data(){};
		data(int _p1,int _p2,int _id,int _opt){p1=_p1,p2=_p2,id=_id,coef=_opt;}
	};
	vector<data>qll,qlr,qrr;
	void insq(vector<data> &q,pii p1,pii p2,int id)
	{
		q.push_back(data(p1.se,p2.se,id,1));
		q.push_back(data(p1.se,p2.fi-1,id,-1));
		q.push_back(data(p1.fi-1,p2.se,id,-1));
		q.push_back(data(p1.fi-1,p2.fi-1,id,1));
	}
	void upd(int u,Dtree &t1,Dtree &t2,int coef)
	{
		sum+=coef*t2.query(u);
		t1.insert(u,coef);
	}
	void solve(vector<data> &q,pii *id1,pii *id2)
	{
		sort(q.begin(),q.end(),[&](data a,data b)
		{
			if(block[a.p1]!=block[b.p1]) return block[a.p1]<block[b.p1];
			return a.p2<b.p2;
		});
		int tmp1=0,tmp2=0;
		for(data now:q)
		{
			while(tmp1<now.p1) ++tmp1,upd(id1[tmp1].fi,t1,t2,id1[tmp1].se);
			while(tmp1>now.p1) upd(id1[tmp1].fi,t1,t2,-id1[tmp1].se),tmp1--;
			while(tmp2<now.p2) ++tmp2,upd(id2[tmp2].fi,t2,t1,id2[tmp2].se);
			while(tmp2>now.p2) upd(id2[tmp2].fi,t2,t1,-id2[tmp2].se),tmp2--;
			ans[now.id]+=now.coef*sum;
		}
		while(tmp1) upd(id1[tmp1].fi,t1,t2,-id1[tmp1].se),tmp1--;
		while(tmp2) upd(id2[tmp2].fi,t2,t1,-id2[tmp2].se),tmp2--;
	}
	void main()
	{
		len=125;
		for(int i=1,t=n<<1;i<=t;i++)
			block[i]=(i-1)/len+1;
		for(int i=1;i<=m;i++)
		{
			int a=read(),b=read(),c=read(),d=read();
			pii al,ar,bl,br;
			Atree::find(a,b,al,ar);
			Atree::find(c,d,bl,br);
			insq(qll,al,bl,i),insq(qlr,al,br,i),insq(qlr,bl,ar,i),insq(qrr,ar,br,i);
		}
		solve(qll,Atree::L.p,Atree::L.p);
		solve(qlr,Atree::L.p,Atree::R.p);
		solve(qrr,Atree::R.p,Atree::R.p);
		for(int i=1;i<=m;i++)
			printf("%lld\n",ans[i]);
	}
}

int main()
{
//	freopen("ex_segment3.in","r",stdin);
//	freopen("ex_segment3_my.out","w",stdout);
	n=read(),m=read();
	Atree::init();
	Btree::init();
	Solve::main();
	return 0;
}
posted @ 2022-10-30 12:19  ez_lcw  阅读(46)  评论(0编辑  收藏  举报