洛谷 P1501 [国家集训队] Tree II（Link-Cut Tree）
Code:
// Luogu P1501 [国家集训队] Tree II, solved with a Link-Cut Tree.
//
// Every node holds a value that starts at 1. Operations (all mod 51061):
//   + a b c      add c to every value on the path a..b
//   - a b c d    cut edge (a, b), then link (c, d)
//   * a b c      multiply every value on the path a..b by c
//   / a b        print the sum of the values on the path a..b
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <string>
using namespace std;

// Redirects stdin/stdout to "<a>.in" / "<a>.out"; only used for local testing.
void setIO(string a) {
    freopen((a + ".in").c_str(), "r", stdin);
    freopen((a + ".out").c_str(), "w", stdout);
}

typedef long long ll;
const int maxn = 100009;  // array capacity; node ids are 1..n, id 0 is the null node
const ll mod = 51061;     // every value and sum is kept modulo this

// Link-Cut Tree state, indexed by node id (node 0 is the null node):
//   f         splay parent, or path-parent when the node is a splay root
//   ch        splay children: [0] = left, [1] = right
//   siz       number of nodes in the splay subtree
//   tag       pending "reverse" flag still owed to the children
//   sta       scratch stack used by splay()
//   mult/add  pending affine tag (v -> v * mult + add) still owed to the
//             children; the node's own val and sumv already include it
//   sumv      sum of val over the splay subtree (mod `mod`)
//   val       the node's own value (mod `mod`)
int f[maxn], ch[maxn][2], siz[maxn], tag[maxn], sta[maxn], n, m;
ll mult[maxn], add[maxn], sumv[maxn], val[maxn];

int lson(int x) { return ch[x][0]; }
int rson(int x) { return ch[x][1]; }

// Side of its splay parent that x hangs on: 1 = right child, 0 = left child.
int get(int x) { return ch[f[x]][1] == x; }

// True when x is the root of its splay tree, i.e. f[x] does not list x as a
// child (f[x] is then only a path-parent, or 0).
int isRoot(int x) { return !(ch[f[x]][1] == x || ch[f[x]][0] == x); }

// Reverses the splay subtree at x: swaps x's children now and records a tag
// so pushdown() repeats the swap one level further down later.
void mark(int x) {
    if (!x) return;
    swap(ch[x][0], ch[x][1]);
    tag[x] ^= 1;
}

// Applies "multiply every value by c" to the splay subtree rooted at x:
// x's value, sum and pending tags are updated now; x's children receive it
// through pushdown(). Expects 0 <= c < mod. No-op for the null node.
void applyMul(int x, ll c) {
    if (!x) return;
    sumv[x] = sumv[x] * c % mod;
    val[x] = val[x] * c % mod;
    mult[x] = mult[x] * c % mod;
    add[x] = add[x] * c % mod;  // (v*m + a)*c == v*(m*c) + a*c
}

// Applies "add c to every value" to the splay subtree rooted at x; relies on
// siz[x] being up to date. Expects 0 <= c < mod. No-op for the null node.
void applyAdd(int x, ll c) {
    if (!x) return;
    sumv[x] = (sumv[x] + c * siz[x]) % mod;
    val[x] = (val[x] + c) % mod;
    add[x] = (add[x] + c) % mod;
}

// Pushes x's pending multiply, add and reverse tags down to its children.
// Multiplication goes first because the stored tag means v -> v*mult + add.
void pushdown(int x) {
    if (mult[x] != 1) {
        applyMul(lson(x), mult[x]);
        applyMul(rson(x), mult[x]);
        mult[x] = 1;
    }
    if (add[x]) {
        applyAdd(lson(x), add[x]);
        applyAdd(rson(x), add[x]);
        add[x] = 0;
    }
    if (tag[x]) {
        mark(ch[x][0]);
        mark(ch[x][1]);
        tag[x] = 0;
    }
}

// Recomputes siz and sumv of x from its children and its own value.
void pushup(int x) {
    if (!x) return;
    siz[x] = siz[lson(x)] + siz[rson(x)] + 1;
    sumv[x] = (sumv[lson(x)] + sumv[rson(x)] + val[x]) % mod;
}

// Rotates o above its splay parent. Only old and o need pushup: the node
// above them either now holds o in old's slot (same node set, same
// aggregate) or is merely a path-parent whose splay children did not change,
// and may still carry lazy tags it has not pushed down.
void rotate(int o) {
    int old = f[o], fold = f[old], which = get(o);
    if (!isRoot(old)) ch[fold][ch[fold][1] == old] = o;  // o takes old's slot
    f[o] = fold;
    ch[old][which] = ch[o][which ^ 1];  // o's inner child moves under old
    if (ch[old][which]) f[ch[old][which]] = old;
    ch[o][which ^ 1] = old;
    f[old] = o;
    pushup(old);
    pushup(o);
}

// Splays x to the root of its splay tree, first pushing down every pending
// tag from that root down to x.
void splay(int x) {
    int v = 0, u = x;
    sta[++v] = u;
    while (!isRoot(u)) sta[++v] = f[u], u = f[u];
    while (v) pushdown(sta[v--]);
    u = f[u];  // path-parent of the splay tree; x stops once f[x] == u
    for (int fa; (fa = f[x]) != u; rotate(x))
        if (f[fa] != u) rotate(get(x) == get(fa) ? fa : x);
}

// Walks from x up through the path-parent pointers, splaying each node and
// replacing its right splay child with the previously processed splay tree,
// so the nodes visited end up joined in one splay tree.
void Access(int x) {
    for (int y = 0; x; y = x, x = f[x]) {
        splay(x);
        ch[x][1] = y;
        pushup(x);
    }
}

// Makes x the root of its tree: access it, splay it, then reverse its splay tree.
void makeRoot(int x) {
    Access(x);
    splay(x);
    mark(x);
}

// Isolates the path x..y: afterwards y is the root of a splay tree holding
// that path, so sumv[y] / siz[y] describe the whole path.
void split(int x, int y) {
    makeRoot(x);
    Access(y);
    splay(y);
}

// Adds edge (x, y); assumes x and y are currently in different trees.
void Link(int x, int y) {
    makeRoot(x);
    f[x] = y;
}

// Removes edge (x, y); assumes the edge exists, so after split x is
// exactly y's left splay child.
void cut(int x, int y) {
    makeRoot(x);
    Access(y);
    splay(y);
    f[x] = ch[y][0] = 0;
    pushup(y);
}

// Prints siz and sumv of every node; debugging aid only.
void debug() {
    for (int i = 1; i <= n; ++i) printf("%d %lld\n", siz[i], sumv[i]);
}

// Reduces an input coefficient into [0, mod) so products stay small.
ll normMod(ll c) { return (c % mod + mod) % mod; }

int main() {
    // setIO("input");
    // Every node starts with a neutral multiply tag of 1. Do not use
    // memset(mult, 1, ...): memset fills each BYTE, which would make every
    // entry 0x0101010101010101 instead of 1.
    fill(mult, mult + maxn, 1LL);
    if (scanf("%d%d", &n, &m) != 2) return 0;
    for (int i = 1; i <= n; ++i) {
        sumv[i] = val[i] = 1;  // every node starts with value 1
        siz[i] = 1;            // a lone node is a one-node splay tree
    }
    for (int i = 1; i < n; ++i) {
        int a, b;
        scanf("%d%d", &a, &b);
        Link(a, b);
    }
    for (int i = 1; i <= m; ++i) {
        char opt[10];
        int a, b, c, d;
        scanf("%9s", opt);
        switch (opt[0]) {
        case '+':  // add c on path a..b
            scanf("%d%d%d", &a, &b, &c);
            split(a, b);
            applyAdd(b, normMod(c));
            break;
        case '-':  // cut edge (a, b), then link (c, d)
            scanf("%d%d%d%d", &a, &b, &c, &d);
            cut(a, b);
            Link(c, d);
            break;
        case '*':  // multiply path a..b by c
            scanf("%d%d%d", &a, &b, &c);
            split(a, b);
            applyMul(b, normMod(c));
            break;
        case '/':  // sum over path a..b
            scanf("%d%d", &a, &b);
            split(a, b);
            printf("%lld\n", sumv[b] % mod);
            break;
        }
    }
    return 0;
}