P1501 [国家集训队]Tree II(LCT)
按题意用 LCT 维护即可。
操作方式和同时支持区间加、区间乘的线段树很像：懒标记按“先乘后加”的顺序下传。
进行修改操作时不要忘记把每个点自身的点权 $v[i]$ 也一并更新。
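还有就是 $51061^2=2607225721>2147483647$，
所以要开 unsigned int（$2607225721<4294967295$，两个取模后的数相乘不会溢出）。注意下传加法标记时要乘子树大小 $siz$，而 $siz$ 最大可达 $10^5$，$51060\times10^5=5106000000>4294967295$，所以要先把 $siz$ 对 $51061$ 取模再相乘（或改用 unsigned long long）。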
// Luogu P1501 "[National Training Team] Tree II", solved with a link-cut tree.
//
// Each splay node keeps its own weight v, the sum s and size siz of its splay
// subtree, and lazy tags: "value -> value * mul + add" plus a reversal flag.
// The tag logic mirrors a segment tree that supports both range add and range
// multiply. All arithmetic is modulo 51061. A product of two reduced values is
// below 51061^2 = 2607225721, which overflows int but fits in unsigned int, so
// every factor (including subtree sizes) is reduced before multiplying.
//
// Commands: '+ u v c' adds c on path u..v, '- u1 v1 u2 v2' cuts edge (u1,v1)
// and links (u2,v2), '* u v c' multiplies path u..v by c, '/ u v' prints the
// path sum modulo 51061.
#include <cstdio>
#include <cstring>
#include <iostream>
#include <utility>

using di = unsigned int;

const di mod = 51061;
const int N = 100005;

// Reads a non-negative decimal integer from stdin, skipping leading
// non-digit characters. Stops (leaving x == 0) if end of input comes first.
void read(int &x) {
    int c = getchar();  // int, not char, so EOF is distinguishable
    x = 0;
    while (c != EOF && (c < '0' || c > '9')) c = getchar();
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
}

// Reduces a value in [0, 2*mod) into [0, mod).
inline di Md(di x) { return x < mod ? x : x - mod; }

int ch[N][2], fa[N], rev[N];
di v[N], add[N], s[N], mul[N], siz[N];

// True iff x is a splay child of fa[x] (i.e. x is not the root of its splay
// tree; otherwise fa[x] is only a path-parent pointer).
inline bool nrt(int x) { return ch[fa[x]][0] == x || ch[fa[x]][1] == x; }

// Recomputes the splay-subtree sum and size of x from its children.
inline void up(int x) {
    const int l = ch[x][0], r = ch[x][1];
    s[x] = Md(Md(s[l] + s[r]) + v[x]);
    siz[x] = siz[l] + siz[r] + 1;
}

// Reverses the splay subtree of x: swaps its children now and marks the
// reversal as pending for its descendants.
inline void Rev(int x) {
    std::swap(ch[x][0], ch[x][1]);
    rev[x] ^= 1;
}

// Applies "value -> value * k + c" to every node in the splay subtree of x,
// composing with any tags already pending there. k and c must be < mod.
// Does nothing for the null node 0.
inline void applyTag(int x, di k, di c) {
    if (!x) return;
    // siz can reach 1e5, and 51060 * 1e5 overflows 32 bits, so reduce it first.
    s[x] = Md(s[x] * k % mod + c * (siz[x] % mod) % mod);
    v[x] = Md(v[x] * k % mod + c);
    add[x] = Md(add[x] * k % mod + c);
    mul[x] = mul[x] * k % mod;
}

// Pushes the pending arithmetic and reversal tags of x to its children.
void down(int x) {
    const int l = ch[x][0], r = ch[x][1];
    if (mul[x] != 1 || add[x] != 0) {
        applyTag(l, mul[x], add[x]);
        applyTag(r, mul[x], add[x]);
        mul[x] = 1;
        add[x] = 0;
    }
    if (rev[x]) {
        if (l) Rev(l);
        if (r) Rev(r);
        rev[x] = 0;
    }
}

// Pushes down tags on every node from the splay root to x, top-down.
// Iterative so a deep splay tree cannot overflow the call stack.
void Pre(int x) {
    static int stk[N];
    int top = 0;
    stk[top++] = x;
    while (nrt(x)) {
        x = fa[x];
        stk[top++] = x;
    }
    while (top > 0) down(stk[--top]);
}

// Rotates x above its parent y, keeping in-order order intact.
void turn(int x) {
    const int y = fa[x], z = fa[y];
    const int d = (ch[y][1] == x);  // which side of y x hangs on
    const int b = ch[x][d ^ 1];     // x's inner subtree, re-hung under y
    if (nrt(y)) ch[z][ch[z][1] == y] = x;  // must run before fa[y] changes
    fa[x] = z;
    ch[y][d] = b;
    if (b) fa[b] = y;
    ch[x][d ^ 1] = y;
    fa[y] = x;
    up(y);
    up(x);
}

// Splays x to the root of its splay tree (zig-zig / zig-zag rotations).
void splay(int x) {
    Pre(x);
    while (nrt(x)) {
        const int y = fa[x], z = fa[y];
        if (nrt(y)) turn((ch[y][1] == x) != (ch[z][1] == y) ? x : y);
        turn(x);
    }
}

// Walks from x toward the tree root, splaying each node and attaching the
// previously processed splay tree as its right child.
void access(int x) {
    for (int y = 0; x; y = x, x = fa[x]) {
        splay(x);
        ch[x][1] = y;
        up(x);
    }
}

// Re-roots the represented tree at x.
inline void makert(int x) {
    access(x);
    splay(x);
    Rev(x);
}

// Returns the root of the represented tree containing x (leftmost node of
// x's splay tree after access), splaying it afterwards.
int findrt(int x) {
    access(x);
    splay(x);
    down(x);
    while (ch[x][0]) {
        x = ch[x][0];
        down(x);
    }
    splay(x);
    return x;
}

// Adds edge (x, y) unless x and y are already connected.
void link(int x, int y) {
    makert(x);
    if (findrt(y) != x) fa[x] = y;
}

// Removes edge (x, y) if it exists.
void cut(int x, int y) {
    makert(x);
    if (findrt(y) == x && fa[y] == x && !ch[y][0]) {
        fa[y] = ch[x][1] = 0;
        up(x);
    }
}

// Isolates the path x..y so that y is the splay root holding exactly that path.
inline void split(int x, int y) {
    makert(x);
    access(y);
    splay(y);
}

int main() {
    int n = 0, m = 0;
    read(n);
    read(m);
    // Every weight starts at 1; s must match v for a lone node.
    for (int i = 1; i <= n; ++i) siz[i] = v[i] = s[i] = mul[i] = 1;
    for (int i = 1; i < n; ++i) {
        int a = 0, b = 0;
        read(a);
        read(b);
        link(a, b);
    }

    char opt[5];
    while (m--) {
        if (scanf("%4s", opt) != 1) break;
        int u = 0, w = 0, c = 0;
        switch (opt[0]) {
        case '+':  // add c to every weight on path u..w
            read(u), read(w), read(c);
            split(u, w);
            applyTag(w, 1, c % mod);
            break;
        case '-':  // cut one edge, then link another
            read(u), read(w);
            cut(u, w);
            read(u), read(w);
            link(u, w);
            break;
        case '*':  // multiply every weight on path u..w by c
            read(u), read(w), read(c);
            split(u, w);
            applyTag(w, c % mod, 0);
            break;
        case '/':  // print the sum of weights on path u..w
            read(u), read(w);
            split(u, w);
            printf("%u\n", s[w]);
            break;
        default:
            break;
        }
    }
    return 0;
}