Luogu「StOI-2」简单的树 树链剖分+线段树+倍增
考场的时候智障了,写了 6k+ 的树链剖分.
如果题目带修改的话可以用树链剖分来维护,但由于没有修改用一个前缀和其实就够了.
求 $\sum_{i=l}^{r} f(a,i)$ 可以写成两个前缀相减的形式.
然后我们就要求 $\sum_{i=0}^{r} f(a,i)$.
求这个的话用倍增讨论 $a$ 的初始值的影响范围,因为在影响范围内刚开始都是由子树中次大值来贡献.
然后这个次大值显然单调,我们就可以找到贡献会比次大值大的临界点,贡献是一个等差数列的形式,维护平方和以及区间和即可.
然后对于 $a$ 的初始值贡献不到的地方也这么讨论一下即可.
如果加上一个带修改还真的挺毒瘤的,不过反正考场上总共花了 60 多分钟就过掉了.
代码:
#include <cstdio> #include <cstring> #include <algorithm> #define N 500009 #define ll long long #define mod 998244353 #define lson now<<1 #define rson now<<1|1 #define setIO(s) freopen(s".in","r",stdin) using namespace std; ll lastans; int edges,n,Q,OPT,tim; int hd[N],to[N<<1],nex[N<<1],val[N]; int fa[N],size[N],son[N],dfn[N],bu[N],f[20][N],dep[N],top[N]; void add(int u,int v) { nex[++edges]=hd[u]; hd[u]=edges,to[edges]=v; } int DECODE(int x) { ll y=1ll*x+1ll*OPT*lastans; y%=n; ++y; return (int)y; } struct data { ll sqr,sum; data() { sqr=sum=0; } data operator+(const data b) const { data c; c.sqr=sqr+b.sqr; c.sum=sum+b.sum; return c; } }; struct node { data se,mx; node operator+(const node b) const { node c; c.se=se+b.se; c.mx=mx+b.mx; return c; } }s[N<<2]; struct Tree { int mx,se; Tree(int mx=0,int se=0):mx(mx),se(se){} Tree operator+(const Tree b) const { Tree c; c.mx=c.se=0; if(mx<b.mx) { c.mx=b.mx; c.se=max(mx,b.se); } if(mx>b.mx) { c.mx=mx; c.se=max(se,b.mx); } if(mx==b.mx) { c.se=c.mx=mx; } return c; } }tree[N]; void dfs0(int x,int ff) { fa[x]=ff,size[x]=1; dep[x]=dep[ff]+1; f[0][x]=fa[x]; tree[x]=Tree(val[x],0); for(int i=hd[x];i;i=nex[i]) { int y=to[i]; if(y==ff) continue; dfs0(y,x); size[x]+=size[y]; if(size[y]>size[son[x]]) son[x]=y; tree[x]=tree[x]+tree[y]; } } void dfs1(int x,int tp) { top[x]=tp; dfn[x]=++tim; bu[tim]=x; if(son[x]) { dfs1(son[x],tp); } for(int i=hd[x];i;i=nex[i]) { int y=to[i]; if(y==fa[x]||y==son[x]) continue; dfs1(y,y); } } void build(int l,int r,int now) { if(l==r) { int cur=bu[l]; s[now].mx.sum=tree[cur].mx; s[now].mx.sqr=1ll*tree[cur].mx*tree[cur].mx; s[now].se.sum=tree[cur].se; s[now].se.sqr=1ll*tree[cur].se*tree[cur].se; return; } int mid=(l+r)>>1; build(l,mid,lson),build(mid+1,r,rson); s[now]=s[lson]+s[rson]; } node query(int l,int r,int now,int L,int R) { if(l>=L&&r<=R) { return s[now]; } int mid=(l+r)>>1; if(L<=mid&&R>mid) return query(l,mid,lson,L,R)+query(mid+1,r,rson,L,R); else if(L<=mid) return query(l,mid,lson,L,R); else return query(mid+1,r,rson,L,R); } node Query(int x,int y) { node re; while(top[x]!=top[y]) { if(dep[top[x]]>dep[top[y]]) { re=re+query(1,n,1,dfn[top[x]],dfn[x]); x=fa[top[x]]; } else { re=re+query(1,n,1,dfn[top[y]],dfn[y]); y=fa[top[y]]; } } if(dep[x]>dep[y]) { swap(x,y); } re=re+query(1,n,1,dfn[x],dfn[y]); return re; } ll solve(int x,int r) { if(r<0) return 0; // 极长最大值小于等于 val[x] 的 int tar=x; for(int i=19;i>=0;--i) { if(!f[i][tar]) continue; // tree[f[i][tar]].mx<=val[x] if(tree[f[i][tar]].mx<=val[x]) { tar=f[i][tar]; } } ll ans=0; if(tree[tar].mx<=val[x]) { // 存在这么一段 // x -> tar 这一段 // 先变成 0,故这一段的贡献先是 int pr=x; for(int i=19;i>=0;--i) { if(!f[i][pr]||dep[f[i][pr]]<dep[tar]) continue; if(tree[f[i][pr]].se<r) pr=f[i][pr]; } if(tree[pr].se<r) { // 等差数列求和 node e=Query(x,pr); ans+=e.se.sum*1ll*(r+1)%mod; // 共 r+1 个时刻 int num=dep[x]-dep[pr]+1; ll tm=1ll*r*r*num-2ll*r*e.se.sum+1ll*r*num+e.se.sqr-e.se.sum; ans+=tm/2; pr=fa[pr]; } // 这部分算好了 if(dep[pr]>=dep[tar]) { // 永远都不变的大哥 node e=Query(pr,tar); ans+=e.se.sum*1ll*(r+1)%mod; ans%=mod; } tar=fa[tar]; } if(tar) { // 其余是要依靠 r 来改变的 node e=Query(tar,1); ans+=e.mx.sum*1ll*(r+1); int pr=tar; for(int i=19;i>=0;--i) { if(!f[i][pr]) continue; if(tree[f[i][pr]].mx<r) pr=f[i][pr]; } if(tree[pr].mx<r) { e=Query(tar,pr); int num=dep[tar]-dep[pr]+1; ll tm=1ll*r*r*num-2ll*r*e.mx.sum+1ll*r*num+e.mx.sqr-e.mx.sum; ans+=tm/2; ans%=mod; } } return ans%mod; } char buf[100000],*p1,*p2; #define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++) int rd() { int x=0; char s=nc(); while(s<'0') s=nc(); while(s>='0') x=(((x<<2)+x)<<1)+s-'0',s=nc(); return x; } int main() { /// setIO("input"); int x,y,z; n=rd(),Q=rd(),OPT=rd(); for(int i=1;i<=n;++i) { val[i]=rd(); } for(int i=1;i<n;++i) { x=rd(),y=rd(); add(x,y),add(y,x); } dfs0(1,0); dfs1(1,1); build(1,n,1); for(int i=1;i<19;++i) { for(int j=1;j<=n;++j) { f[i][j]=f[i-1][f[i-1][j]]; } } ll fin=s[1].mx.sum; for(int i=1;i<=Q;++i) { int l=rd(),r=rd(),a=rd(); l=DECODE(l),r=DECODE(r),a=DECODE(a); if(l>r) { swap(l,r); } node e=Query(1,a); ll cur=(fin-1ll*e.mx.sum%mod)*(r-l+1)%mod; printf("%lld\n",lastans=(ll)(cur+solve(a,r)-solve(a,l-1)+mod)%mod); } return 0; }