[ZJOI2017] 线段树

一、题目

点此看题

二、解法

利用 \(\tt zkw\) 的哨兵思想,我们先添加位置 \(0\)\(n+1\),略微地改动这棵线段树。对于询问 \([l,r]\),我们从 \(l-1\) 的叶子和 \(r+1\) 的叶子开始跳(分别记为 \(L,R\)),设它们的 \(\tt lca\)\(t\),那么定位的结果就是:

  • \(L\)\(t\) 左儿子链上,所有左儿子对应的右兄弟。
  • \(R\)\(t\) 右儿子链上,所有右儿子对应的左兄弟。

问题是算 \(u\) 到这些定位节点的距离和,转化为计算 \(\tt lca\) 的深度和,考虑这些点的分布是很有规律的(附着在树上的两条链上),那么分类讨论,按照 \(u\)\(t\) 的位置关系分为三类:

\(u=t\) 或者 \(u\)\(t\) 的子树外:那么所有定位点的 \(\tt lca\) 就是 \(lca(u,t)\),统计出个数即可。

\(u\)\(t\) 的左子树内:所有右子树定位点的 \(\tt lca\) 就是 \(t\),设 \(x=lca(L,u)\),在 \(x\) 以下的定位点的 \(\tt lca\)\(x\),在 \(x\) 以上的定位点的 \(\tt lca\) 是定位点的父亲,统计出个数及深度和即可。注意特判 \(u\)\(x\) 右子树的情况,因为这样那个端点的 \(\tt lca\)\(x\) 的右儿子。

\(u\)\(t\) 的右子树内:与 ② 本质相同。

那么方法就很简单了,只需要维护出到根链上的左右兄弟个数及其深度和,想用什么直接差分,时间复杂度 \(O(n\log n)\)

#include <cstdio>
#include <iostream>
using namespace std;
const int M = 400010;
#define int long long
int read()
{
	int x=0,f=1;char c;
	while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
	while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
	return x*f;
}
int n,m,ch[M][2],pos[M],dep[M],siz[M],id[M];
int Ind,f[M][20],cnt[M][2],sum[M][2];
int build(int l,int r)
{
	int u=++m;
	if(l==r) {pos[l]=u;return u;}
	int mid=read();
	ch[u][0]=build(l,mid);
	ch[u][1]=build(mid+1,r);
	return u;
}
void dfs1(int u,int fa)
{
	dep[u]=dep[fa]+1;
	siz[u]=1;id[u]=++Ind;
	for(int i=1;i<20;i++)
		f[u][i]=f[f[u][i-1]][i-1];
	for(int i=0;i<2;i++) if(ch[u][i])
	{
		f[ch[u][i]][0]=u;dfs1(ch[u][i],u);
		siz[u]+=siz[ch[u][i]];
	}
}
void dfs2(int u)
{
	for(int i=0;i<2;i++) if(ch[u][i])
	{
		int v=ch[u][i];
		for(int j=0;j<2;j++)
			sum[v][j]=sum[u][j],
			cnt[v][j]=cnt[u][j];
		if(ch[u][i^1])
			cnt[v][i^1]++,
			sum[v][i^1]+=dep[ch[u][i^1]];
		dfs2(v);
	}
}
int lca(int u,int v)
{
	if(dep[u]<dep[v]) swap(u,v);
	for(int i=19;i>=0;i--)
		if(dep[f[u][i]]>=dep[v])
			u=f[u][i];
	if(u==v) return u;
	for(int i=19;i>=0;i--)
		if(f[u][i]^f[v][i])
			u=f[u][i],v=f[v][i];
	return f[u][0];
}
signed main()
{
	n=read();build(1,n);
	ch[m+2][0]=pos[0]=m+1;ch[m+2][1]=1;
	ch[m+3][0]=m+2;ch[m+3][1]=pos[n+1]=m+4;
	dfs1(m+3,0);dfs2(m+3);
	m=read();
	while(m--)
	{
		int u=read(),l=read(),r=read();
		int t=lca(l=pos[l-1],r=pos[r+1]);
		int ls=ch[t][0],rs=ch[t][1],x=0;
		int ans=(sum[l][1]-sum[ls][1]+sum[r][0]-sum[rs][0])
		+dep[u]*(cnt[l][1]-cnt[ls][1]+cnt[r][0]-cnt[rs][0]);
		if(id[u]<=id[t] || id[u]>=id[t]+siz[t])
		// u out of the substree of t / or u=t
		{
			x=lca(t,u);
			ans-=2*dep[x]*(cnt[l][1]-cnt[ls][1]+cnt[r][0]-cnt[rs][0]);
		}
		else if(id[u]>=id[ls] && id[u]<id[ls]+siz[ls])
		// u in the left substree of t
		{
			x=lca(l,u);
			ans-=2*dep[t]*(cnt[r][0]-cnt[rs][0]);
			ans-=2*dep[x]*(cnt[l][1]-cnt[x][1]);
			ans-=2*((sum[x][1]-sum[ls][1])-(cnt[x][1]-cnt[ls][1]));
			id[u]>=id[ch[x][1]] && id[u]<id[ch[x][1]]+siz[ch[x][1]] && (ans-=2);
		}
		else
		// u in the right substree of t
		{
			x=lca(r,u);
			ans-=2*dep[t]*(cnt[l][1]-cnt[ls][1]);
			ans-=2*dep[x]*(cnt[r][0]-cnt[x][0]);
			ans-=2*((sum[x][0]-sum[rs][0])-(cnt[x][0]-cnt[rs][0]));
			id[u]>=id[ch[x][0]] && id[u]<id[ch[x][0]]+siz[ch[x][0]] && (ans-=2);
		}
		printf("%lld\n",ans);
	}
}
posted @ 2022-07-19 15:17  C202044zxy  阅读(128)  评论(0编辑  收藏  举报