树链剖分学习
两次dfs,第一次处理处fa[],depth[],size[],son[],第二次处理出top[],rank[],id[]
一条重链的编号是连续的,可以用数据结构维护,做事情的时候判断是否在同一条链上,不是就把最深的跳到链头的fa[],然后继续判断
luogu3384 【模板】树链剖分
区间加,求和,子树加,求和
#include<iostream> #include<cstdio> #include<cstdlib> #include<cstring> #include<cmath> #include<algorithm> #define ll long long using namespace std; const int maxn=200005; inline int read(){ int x=0,k=1;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-') k=-1;ch=getchar();} while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();} return k*x; } int rt,cnt,ecnt,fir[maxn],nex[maxn],to[maxn]; int f[maxn],d[maxn],size[maxn],son[maxn],rk[maxn],top[maxn],id[maxn],val[maxn]; struct node{ int l,r; ll v,add; }a[maxn*2]; int n,m,R,opt,x,y; ll p,k; void build(int u,int l,int r){ a[u].l=l;a[u].r=r;a[u].add=0; if(l==r) a[u].v=val[rk[l]]; else{ int mid=l+r>>1; build(u<<1,l,mid); build(u<<1|1,mid+1,r); a[u].v=a[u<<1].v+a[u<<1|1].v; } } void addd(int u,ll v){ a[u].add+=v; a[u].add%=p; a[u].v+=(a[u].r-a[u].l+1)*v; a[u].v%=p; } void pushdown(int u){ if(a[u].add){ addd(u<<1,a[u].add); addd(u<<1|1,a[u].add); a[u].add=0; } } void update(int u,int l,int r,ll v){ if(a[u].l==l&&a[u].r==r) addd(u,v); else{ int mid=(a[u].l+a[u].r)>>1; pushdown(u); if(r<=mid) update(u<<1,l,r,v); else if(l>mid) update(u<<1|1,l,r,v); else update(u<<1,l,mid,v),update(u<<1|1,mid+1,r,v); a[u].v=a[u<<1].v+a[u<<1|1].v; } } ll query(int u,int l,int r){ if(a[u].l==l&&a[u].r==r) return a[u].v; int mid=(a[u].l+a[u].r)>>1; pushdown(u); if(r<=mid) return query(u<<1,l,r); else if(l>mid) return query(u<<1|1,l,r); return (query(u<<1,l,mid)+query(u<<1|1,mid+1,r))%p; } void addedge(int u,int v){ nex[++ecnt]=fir[u];fir[u]=ecnt;to[ecnt]=v; } void dfs1(int u,int fa,int depth){ f[u]=fa; d[u]=depth; size[u]=1; for(int i=fir[u];i;i=nex[i]){ int v=to[i]; if(v==fa) continue; dfs1(v,u,depth+1); size[u]+=size[v]; if(size[v]>size[son[u]]) son[u]=v; } } void dfs2(int u,int t){ top[u]=t; id[u]=++cnt; rk[cnt]=u; if(!son[u]) return; dfs2(son[u],t); for(int i=fir[u];i;i=nex[i]){ int v=to[i]; if(v!=son[u]&&v!=f[u]) dfs2(v,v); } } int sum(int x,int y){ int ans=0,fx=top[x],fy=top[y]; while(fx!=fy){ if(d[fx]>=d[fy]){ ans+=query(1,id[fx],id[x]); ans%=p; x=f[fx]; } else{ ans+=query(1,id[fy],id[y]); ans%=p; y=f[fy]; } fx=top[x]; fy=top[y]; } if(id[x]<=id[y]) ans+=query(1,id[x],id[y]),ans%=p; else ans+=query(1,id[y],id[x]),ans%=p; } void updates(int x,int y,int c){ int fx=top[x],fy=top[y]; while(fx!=fy){ if(d[fx]>=d[fy]){ update(1,id[fx],id[x],c); x=f[fx]; } else{ update(1,id[fy],id[y],c); y=f[fy]; } fx=top[x]; fy=top[y]; } if(id[x]<=id[y]) update(1,id[x],id[y],c); else update(1,id[y],id[x],c); } int main(){ // freopen(".in","r",stdin); // freopen(".out","w",stdout); cin>>n>>m>>R>>p; for(int i=1;i<=n;i++){ val[i]=read(); } for(int i=1;i<n;i++){ int x,y; x=read();y=read(); addedge(x,y); addedge(y,x); } cnt=0; dfs1(R,0,1); dfs2(R,R); cnt=1; build(1,1,n); for(int i=1;i<=m;i++){ int opt,x,y,z; opt=read(); if(opt==1){ x=read();y=read();z=read();z%=p; updates(x,y,z); } else if(opt==2){ x=read();y=read(); cout<<sum(x,y)<<endl; } else if(opt==3){ x=read();z=read(); update(1,id[x],id[x]+size[x]-1,z%p); } else{ x=read(); cout<<query(1,id[x],id[x]+size[x]-1)<<endl; } } return 0; }
相当于n-1次+1,每次把a[i]~a[i+1]的区间+1,由于a[2]~a[n]会加两次,所以这些还要-1
// luogu-judger-enable-o2 #include<iostream> #include<cstdio> #include<cstdlib> #include<cstring> #include<cmath> #include<algorithm> #define ll long long using namespace std; const int maxn=300005; const int maxm=600005; inline int read(){ int x=0,k=1;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-') k=-1;ch=getchar();} while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();} return k*x; } int rt,cnt,ecnt,fir[maxn],nex[maxm],to[maxm]; int f[maxn],d[maxn],size[maxn],son[maxn],rk[maxn],top[maxn],id[maxn],val[maxn]; struct node{ int l,r; ll v,add; }a[maxm*2]; int n,m,R,opt,x,y; ll p,k; void build(int u,int l,int r){ a[u].l=l;a[u].r=r;a[u].add=0; if(l==r) return; else{ int mid=l+r>>1; build(u<<1,l,mid); build(u<<1|1,mid+1,r); a[u].v=a[u<<1].v+a[u<<1|1].v; } } void addd(int u,ll v){ a[u].add+=v; a[u].v+=(a[u].r-a[u].l+1)*v; } void pushdown(int u){ if(a[u].add){ addd(u<<1,a[u].add); addd(u<<1|1,a[u].add); a[u].add=0; } } void update(int u,int l,int r,ll v){ if(a[u].l==l&&a[u].r==r) addd(u,v); else{ int mid=(a[u].l+a[u].r)>>1; pushdown(u); if(r<=mid) update(u<<1,l,r,v); else if(l>mid) update(u<<1|1,l,r,v); else update(u<<1,l,mid,v),update(u<<1|1,mid+1,r,v); a[u].v=a[u<<1].v+a[u<<1|1].v; } } ll query(int u,int l,int r){ if(a[u].l==l&&a[u].r==r) return a[u].v; int mid=(a[u].l+a[u].r)>>1; pushdown(u); if(r<=mid) return query(u<<1,l,r); else if(l>mid) return query(u<<1|1,l,r); return (query(u<<1,l,mid)+query(u<<1|1,mid+1,r)); } void addedge(int u,int v){ nex[++ecnt]=fir[u];fir[u]=ecnt;to[ecnt]=v; } void dfs1(int u,int fa,int depth){ f[u]=fa; d[u]=depth; size[u]=1; for(int i=fir[u];i;i=nex[i]){ int v=to[i]; if(v==fa) continue; dfs1(v,u,depth+1); size[u]+=size[v]; if(size[v]>size[son[u]]) son[u]=v; } } void dfs2(int u,int t){ top[u]=t; id[u]=++cnt; rk[cnt]=u; if(!son[u]) return; dfs2(son[u],t); for(int i=fir[u];i;i=nex[i]){ int v=to[i]; if(v!=son[u]&&v!=f[u]) dfs2(v,v); } } int sum(int x,int y){ int ans=0,fx=top[x],fy=top[y]; while(fx!=fy){ if(d[fx]>=d[fy]){ ans+=query(1,id[fx],id[x]); ans%=p; x=f[fx]; } else{ ans+=query(1,id[fy],id[y]); ans%=p; y=f[fy]; } fx=top[x]; fy=top[y]; } if(id[x]<=id[y]) ans+=query(1,id[x],id[y]); else ans+=query(1,id[y],id[x]); } void updates(int x,int y,int c){ int fx=top[x],fy=top[y]; while(fx!=fy){ if(d[fx]>=d[fy]){ update(1,id[fx],id[x],c); x=f[fx]; } else{ update(1,id[fy],id[y],c); y=f[fy]; } fx=top[x]; fy=top[y]; } if(id[x]<=id[y]) update(1,id[x],id[y],c); else update(1,id[y],id[x],c); } int main(){ // freopen(".in","r",stdin); // freopen(".out","w",stdout); n=read(); for(int i=1;i<=n;i++) val[i]=read(); for(int i=1;i<n;i++){ int x,y; x=read();y=read(); addedge(x,y);addedge(y,x); } dfs1(1,0,1); dfs2(1,1); build(1,1,n); for(int i=1;i<n;i++){ updates(val[i],val[i+1],1); updates(val[i+1],val[i+1],-1); } for(int i=1;i<=n;i++){ cout<<query(1,id[i],id[i])<<endl; } return 0; }