【洛谷5439】【XR-2】永恒(树链剖分,线段树)
【洛谷5439】【XR-2】永恒(树链剖分,线段树)
题面
题解
首先两个点的\(LCP\)就是\(Trie\)树上的\(LCA\)的深度。
考虑一对点的贡献,如果这两个点不具有祖先关系,那么这对点被计算的次数是\(size[u]*size[v]\)次。否则具有祖先关系,假设\(u\)是\(v\)祖先,则是\(size[v]*(n-size[u]+1)\)次。
于是先考虑所有点不具有祖先关系,再减去有祖先关系的情况就好了。
然后现在知道了统计的次数,还需要知道统计的值,显然这个\(len\)可以从\(LCA\)到根节点在每个点都统计一次,那么就是每次链加链求和就行了。
怎么算就看代码吧。。。
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<vector>
using namespace std;
#define ll long long
#define MAX 300300
#define MOD 998244353
inline int read()
{
int x=0;bool t=false;char ch=getchar();
while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
if(ch=='-')t=true,ch=getchar();
while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
return t?-x:x;
}
int n,m,rt,ans,a[MAX];
vector<int> E[MAX],TE[MAX];
int dep[MAX],fa[MAX],sz[MAX],hson[MAX],top[MAX],dfn[MAX],tim;
void dfs1(int u,int ff)
{
fa[u]=ff;dep[u]=dep[ff]+1;sz[u]=1;
for(int v:TE[u])
{
dfs1(v,u),sz[u]+=sz[v];
if(sz[v]>sz[hson[u]])hson[u]=v;
}
}
void dfs2(int u,int tp)
{
top[u]=tp;dfn[u]=++tim;
if(hson[u])dfs2(hson[u],tp);
for(int v:TE[u])if(v!=hson[u])dfs2(v,v);
}
#define lson (now<<1)
#define rson (now<<1|1)
int sum[MAX<<2],tag[MAX<<2];
void Modify(int now,int l,int r,int L,int R,int w)
{
if(L<=l&&r<=R)
{
sum[now]=(sum[now]+1ll*(r-l+1)*w)%MOD;
tag[now]=(tag[now]+w)%MOD;
return;
}
int mid=(l+r)>>1;
if(L<=mid)Modify(lson,l,mid,L,R,w);
if(R>mid)Modify(rson,mid+1,r,L,R,w);
sum[now]=(sum[lson]+sum[rson]+1ll*tag[now]*(r-l+1))%MOD;
}
int Query(int now,int l,int r,int L,int R)
{
if(L==l&&r==R)return sum[now];
int mid=(l+r)>>1,ret=1ll*tag[now]*(R-L+1)%MOD;
if(R<=mid)return (ret+Query(lson,l,mid,L,R))%MOD;
if(L>mid)return (ret+Query(rson,mid+1,r,L,R))%MOD;
return (0ll+ret+Query(lson,l,mid,L,mid)+Query(rson,mid+1,r,mid+1,R))%MOD;
}
void Modify(int u,int w){while(u)Modify(1,1,m,dfn[top[u]],dfn[u],w),u=fa[top[u]];}
int Query(int u){int s=0;while(u)s=(s+Query(1,1,m,dfn[top[u]],dfn[u]))%MOD,u=fa[top[u]];return s;}
void dfs(int u){sz[u]=1;for(int v:E[u])dfs(v),sz[u]+=sz[v];}
void DFS(int u)
{
ans=(ans+1ll*sz[u]*Query(a[u])%MOD)%MOD;
Modify(a[u],MOD-sz[u]);
for(int v:E[u])
{
Modify(a[u],n-sz[v]);
DFS(v);
Modify(a[u],MOD-(n-sz[v]));
}
Modify(a[u],sz[u]);
}
int main()
{
n=read();m=read();
for(int i=1;i<=n;++i)E[read()].push_back(i);
for(int i=1;i<=m;++i)TE[read()].push_back(i);
scanf("%*s");
for(int i=1;i<=n;++i)a[i]=read();
for(int v:TE[1])dfs1(v,0),dfs2(v,v);
dfs(E[0][0]);
for(int i=1;i<=n;++i)ans=(ans+1ll*sz[i]*Query(a[i]))%MOD,Modify(a[i],sz[i]);
for(int i=1;i<=n;++i)Modify(a[i],MOD-sz[i]);
DFS(E[0][0]);
printf("%d\n",ans);
return 0;
}