【洛谷5439】【XR-2】永恒(树链剖分,线段树)

【洛谷5439】【XR-2】永恒(树链剖分,线段树)

题面

洛谷

题解

首先两个点的\(LCP\)就是\(Trie\)树上的\(LCA\)的深度。
考虑一对点的贡献,如果这两个点不具有祖先关系,那么这对点被计算的次数是\(size[u]*size[v]\)次。否则具有祖先关系,假设\(u\)\(v\)祖先,则是\(size[v]*(n-size[u]+1)\)次。
于是先考虑所有点不具有祖先关系,再减去有祖先关系的情况就好了。
然后现在知道了统计的次数,还需要知道统计的值,显然这个\(len\)可以从\(LCA\)到根节点在每个点都统计一次,那么就是每次链加链求和就行了。
怎么算就看代码吧。。。

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<vector>
using namespace std;
#define ll long long
#define MAX 300300
#define MOD 998244353
inline int read()
{
	int x=0;bool t=false;char ch=getchar();
	while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
	if(ch=='-')t=true,ch=getchar();
	while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
	return t?-x:x;
}
int n,m,rt,ans,a[MAX];
vector<int> E[MAX],TE[MAX];
int dep[MAX],fa[MAX],sz[MAX],hson[MAX],top[MAX],dfn[MAX],tim;
void dfs1(int u,int ff)
{
	fa[u]=ff;dep[u]=dep[ff]+1;sz[u]=1;
	for(int v:TE[u])
	{
		dfs1(v,u),sz[u]+=sz[v];
		if(sz[v]>sz[hson[u]])hson[u]=v;
	}
}
void dfs2(int u,int tp)
{
	top[u]=tp;dfn[u]=++tim;
	if(hson[u])dfs2(hson[u],tp);
	for(int v:TE[u])if(v!=hson[u])dfs2(v,v);
}
#define lson (now<<1)
#define rson (now<<1|1)
int sum[MAX<<2],tag[MAX<<2];
void Modify(int now,int l,int r,int L,int R,int w)
{
	if(L<=l&&r<=R)
	{
		sum[now]=(sum[now]+1ll*(r-l+1)*w)%MOD;
		tag[now]=(tag[now]+w)%MOD;
		return;
	}
	int mid=(l+r)>>1;
	if(L<=mid)Modify(lson,l,mid,L,R,w);
	if(R>mid)Modify(rson,mid+1,r,L,R,w);
	sum[now]=(sum[lson]+sum[rson]+1ll*tag[now]*(r-l+1))%MOD;
}
int Query(int now,int l,int r,int L,int R)
{
	if(L==l&&r==R)return sum[now];
	int mid=(l+r)>>1,ret=1ll*tag[now]*(R-L+1)%MOD;
	if(R<=mid)return (ret+Query(lson,l,mid,L,R))%MOD;
	if(L>mid)return (ret+Query(rson,mid+1,r,L,R))%MOD;
	return (0ll+ret+Query(lson,l,mid,L,mid)+Query(rson,mid+1,r,mid+1,R))%MOD;
}
void Modify(int u,int w){while(u)Modify(1,1,m,dfn[top[u]],dfn[u],w),u=fa[top[u]];}
int Query(int u){int s=0;while(u)s=(s+Query(1,1,m,dfn[top[u]],dfn[u]))%MOD,u=fa[top[u]];return s;}
void dfs(int u){sz[u]=1;for(int v:E[u])dfs(v),sz[u]+=sz[v];}
void DFS(int u)
{
	ans=(ans+1ll*sz[u]*Query(a[u])%MOD)%MOD;
	Modify(a[u],MOD-sz[u]);
	for(int v:E[u])
	{
		Modify(a[u],n-sz[v]);
		DFS(v);
		Modify(a[u],MOD-(n-sz[v]));
	}
	Modify(a[u],sz[u]);
}
int main()
{
	n=read();m=read();
	for(int i=1;i<=n;++i)E[read()].push_back(i);
	for(int i=1;i<=m;++i)TE[read()].push_back(i);
	scanf("%*s");
	for(int i=1;i<=n;++i)a[i]=read();
	for(int v:TE[1])dfs1(v,0),dfs2(v,v);
	dfs(E[0][0]);
	for(int i=1;i<=n;++i)ans=(ans+1ll*sz[i]*Query(a[i]))%MOD,Modify(a[i],sz[i]);
	for(int i=1;i<=n;++i)Modify(a[i],MOD-sz[i]);
	DFS(E[0][0]);
	printf("%d\n",ans);
	return 0;
}
posted @ 2019-07-01 17:30  小蒟蒻yyb  阅读(528)  评论(0编辑  收藏  举报