树链剖分
洛谷 P3384 树链剖分模板题：https://www.luogu.com.cn/problem/P3384
代码:
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;

// Heavy-light decomposition over a lazy-propagation segment tree (Luogu P3384).
// All sums and additions are taken modulo p. Operations:
//   1 x y z : add z to every node on the path x..y
//   2 x y   : print the sum over the path x..y
//   3 x z   : add z to every node in the subtree of x
//   4 x     : print the sum over the subtree of x
// NOTE(review): the modular arithmetic assumes non-negative inputs and p > 0.
const int maxn=100000+10;
struct my{
    int v,Next;   // edge target, index of the next edge in the same list
};
typedef long long ll;
my edge[maxn*2];                 // each undirected edge is stored twice
ll p;                            // modulus
int adj[maxn],tot;               // adj[u]: first edge out of u (0 = none); tot: edges used
ll w[maxn];                      // node weights, indexed by original label
int dep[maxn],son[maxn],siz[maxn],fa[maxn],top[maxn];
ll tree[maxn<<2],add[maxn<<2];   // segment sums and pending (lazy) additions
int id[maxn],cnt;                // id[x]: DFS position of x; cnt: last position assigned
ll nw[maxn];                     // nw[id[x]] == w[x]
int n;

// First DFS: records depth, parent and subtree size of every node and picks
// the heavy child (the child with the largest subtree) into son[x].
void dfs1(int x,int f,int deep){
    dep[x]=deep;
    siz[x]=1;
    fa[x]=f;
    int maxx=-1;
    for(int i=adj[x];i;i=edge[i].Next){
        int v=edge[i].v;
        if(v==f) continue;
        dfs1(v,x,deep+1);
        siz[x]+=siz[v];
        if(maxx<siz[v]) son[x]=v,maxx=siz[v];
    }
}

// Second DFS: hands out DFS positions, always visiting the heavy child first.
// This makes every heavy chain and every subtree a contiguous range of
// positions, so both can be served by one segment tree.
void dfs2(int x,int f,int topf){
    id[x]=++cnt;
    nw[cnt]=w[x];
    top[x]=topf;            // topmost node of the chain x belongs to
    if(!son[x]) return;
    dfs2(son[x],x,topf);    // heavy child continues the current chain
    for(int i=adj[x];i;i=edge[i].Next){
        int v=edge[i].v;
        if(v==f||v==son[x]) continue;
        dfs2(v,x,v);        // each light child starts a new chain
    }
}

// Appends the directed edge u -> v to u's adjacency list.
void myinsert(int u,int v){
    edge[++tot].v=v;
    edge[tot].Next=adj[u];
    adj[u]=tot;
}

// Pushes node x's pending addition to its children; ln / rn are the sizes of
// the left and right child ranges.
void pushdown(int x,int ln,int rn){
    if(add[x]){
        add[x<<1]=(add[x<<1]+add[x])%p;
        add[x<<1|1]=(add[x<<1|1]+add[x])%p;
        tree[x<<1]=(tree[x<<1]+add[x]*ln%p)%p;
        tree[x<<1|1]=(tree[x<<1|1]+add[x]*rn%p)%p;
        add[x]=0;
    }
}

// Recomputes node x's sum from its children.
void pushup(int x){
    tree[x]=(tree[x<<1]+tree[x<<1|1])%p;
}

// Builds node rt covering positions [l, r] from nw[].
void build(int l,int r,int rt){
    if(l==r){
        tree[rt]=nw[l]%p;   // keep every stored value reduced mod p
        return;
    }
    int mid=(l+r)>>1;
    build(l,mid,rt<<1);
    build(mid+1,r,rt<<1|1);
    pushup(rt);
}

// Adds c to every position in [L, R]; node rt covers [l, r].
void change(int l,int r,int rt,int L,int R,ll c){
    if(l>=L&&r<=R){
        add[rt]=(add[rt]+c)%p;
        tree[rt]=(tree[rt]+(r-l+1)*c%p)%p;
        return;
    }
    int mid=(l+r)>>1;
    pushdown(rt,mid-l+1,r-mid);
    if(L<=mid) change(l,mid,rt<<1,L,R,c);
    if(R>mid) change(mid+1,r,rt<<1|1,L,R,c);
    pushup(rt);
}

// Returns the sum (mod p) of positions [L, R]; node rt covers [l, r].
ll getans(int l,int r,int rt,int L,int R){
    if(l>=L&&r<=R){
        return tree[rt];
    }
    int mid=(l+r)>>1;
    pushdown(rt,mid-l+1,r-mid);
    ll ans=0;
    if(L<=mid) ans=(ans+getans(l,mid,rt<<1,L,R))%p;
    if(R>mid) ans=(ans+getans(mid+1,r,rt<<1|1,L,R))%p;
    return ans%p;
}

// Sum (mod p) over the tree path x..y: repeatedly lift whichever endpoint has
// the deeper chain top, consuming that chain segment, until both endpoints
// share a chain; then add the remaining contiguous segment.
ll getans1(int x,int y){
    ll ans=0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        ans=(ans+getans(1,n,1,id[top[x]],id[x]))%p;
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    ans=(ans+getans(1,n,1,id[x],id[y]))%p;
    return ans%p;
}

// Adds z to every node on the tree path x..y (same chain-lifting as getans1).
void change1(int x,int y,ll z){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        change(1,n,1,id[top[x]],id[x],z);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    change(1,n,1,id[x],id[y],z);
}

// Sum (mod p) over the subtree of x, which occupies positions
// [id[x], id[x]+siz[x]-1] in DFS order.
ll getans2(int x){
    return getans(1,n,1,id[x],id[x]+siz[x]-1)%p;
}

// Adds z to every node in the subtree of x.
void change2(int x,ll z){
    change(1,n,1,id[x],id[x]+siz[x]-1,z);
}

int main(){
#ifdef LOCAL
    // Local debugging only. Left unconditional, this closes stdin and fails
    // whenever read.in is absent (e.g. on a judge), breaking every scanf.
    freopen("read.in","r",stdin);
#endif
    int m,root;
    if(scanf("%d%d%d%lld",&n,&m,&root,&p)!=4) return 0;
    for(int i=1;i<=n;i++){
        scanf("%lld",&w[i]);
    }
    for(int i=1;i<n;i++){
        int u,v;
        scanf("%d%d",&u,&v);
        myinsert(u,v);
        myinsert(v,u);
    }
    dfs1(root,root,1);
    dfs2(root,root,root);
    build(1,n,1);
    int opt,x,y;
    ll z;
    while(m--){
        scanf("%d",&opt);
        if(opt==1){
            scanf("%d%d%lld",&x,&y,&z);
            change1(x,y,z);
        }
        else if(opt==2){
            scanf("%d%d",&x,&y);
            printf("%lld\n",getans1(x,y));
        }
        else if(opt==3){
            scanf("%d%lld",&x,&z);
            change2(x,z);
        }
        else if(opt==4){
            scanf("%d",&x);
            printf("%lld\n",getans2(x));
        }
    }
    return 0;
}
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;

// Heavy-light decomposition with one dynamically allocated segment tree per
// religion (presumably Luogu P3313 / SDOI2014 "旅行" — TODO confirm).
// Each command is a two-letter code read into ch:
//   CC x c : city x changes religion to c
//   CW x w : city x changes rating to w
//   QS x y : sum of ratings on path x..y over cities of x's religion
//   QM x y : max rating on path x..y over cities of x's religion
struct my{
    int v,Next;          // edge target, index of the next edge in the same list
};
struct node{
    int l,r,max,sum;     // child indices (0 = empty), max and sum of the range
};
const int maxn=100000+10;
// Node pool size. One insertion allocates at most ~18 nodes (depth of a
// segment tree over 1e5 leaves). Allocation happens only for the n initial
// insertions and for each religion change (del keeps nodes, and CW reuses the
// existing path), so 2*maxn*18 nodes suffice assuming at most ~1e5 updates.
// The former maxn*25 (~2.5e6) could overflow in the worst case (~3.6e6).
// Cost: about 64 MB.
const int maxnode=2*maxn*20;
int adj[maxn],tot,n;
char ch[50];
int w[maxn],c[maxn],dep[maxn],fa[maxn],son[maxn],siz[maxn];
int top[maxn],cnt,len,mp[maxn],id[maxn];   // mp[pos]: node at DFS position pos; len: nodes used
int root[maxn];                             // root[k]: root node of religion k's tree (0 = empty)
my edge[maxn<<1];
node tree[maxnode];                         // tree[0] stays all-zero and acts as the empty node

// Appends the directed edge u -> v to u's adjacency list.
void myinsert(int u,int v){
    edge[++tot].v=v;
    edge[tot].Next=adj[u];
    adj[u]=tot;
}

// First DFS: depth, parent, subtree size, and heavy child son[x].
void dfs1(int x,int f,int deep){
    siz[x]=1;
    dep[x]=deep;
    fa[x]=f;
    int maxx=-1;
    for(int i=adj[x];i;i=edge[i].Next){
        int v=edge[i].v;
        if(v==f) continue;
        dfs1(v,x,deep+1);
        siz[x]+=siz[v];
        if(siz[v]>maxx) maxx=siz[v],son[x]=v;
    }
}

// Second DFS: DFS positions with the heavy child visited first, so every heavy
// chain is a contiguous range of positions; top[x] is the head of x's chain.
void dfs2(int x,int f,int topf){
    id[x]=++cnt;
    mp[cnt]=x;
    top[x]=topf;
    if(!son[x]) return;
    dfs2(son[x],x,topf);
    for(int i=adj[x];i;i=edge[i].Next){
        int v=edge[i].v;
        if(v==f||v==son[x]) continue;
        dfs2(v,x,v);    // each light child starts a new chain
    }
}

// Recomputes node x's max and sum from its children.
void pushup(int x){
    tree[x].max=max(tree[tree[x].l].max,tree[tree[x].r].max);
    tree[x].sum=tree[tree[x].l].sum+tree[tree[x].r].sum;
}

// Sets position pos to value C in the tree rooted at rt (covering [l, r]),
// allocating missing nodes on the way down.
void update(int &rt,int l,int r,int pos,int C){
    if(!rt) rt=++len;
    if(l==r){
        tree[rt].max=C;
        tree[rt].sum=C;
        return;
    }
    int mid=(l+r)>>1;
    if(pos<=mid) update(tree[rt].l,l,mid,pos,C);
    else update(tree[rt].r,mid+1,r,pos,C);
    pushup(rt);
}

// Zeroes position pos in the tree rooted at rt; nodes are kept, not freed.
void del(int &rt,int l,int r,int pos){
    if(!rt) return;
    if(l==r){
        tree[rt].max=tree[rt].sum=0;
        return;
    }
    int mid=(l+r)>>1;
    if(pos<=mid) del(tree[rt].l,l,mid,pos);
    else del(tree[rt].r,mid+1,r,pos);
    pushup(rt);
}

// Maximum over positions [L, R] in the tree rooted at rt (0 if empty).
int getmax(int rt,int l,int r,int L,int R){
    if(!rt) return 0;
    if(l>=L&&r<=R){
        return tree[rt].max;
    }
    int ans=0;
    int mid=(l+r)>>1;
    if(L<=mid) ans=max(getmax(tree[rt].l,l,mid,L,R),ans);
    if(R>mid) ans=max(ans,getmax(tree[rt].r,mid+1,r,L,R));
    return ans;
}

// Sum over positions [L, R] in the tree rooted at rt (0 if empty).
int getsum(int rt,int l,int r,int L,int R){
    if(!rt) return 0;
    if(l>=L&&r<=R){
        return tree[rt].sum;
    }
    int mid=(l+r)>>1;
    int ans=0;
    if(L<=mid) ans+=getsum(tree[rt].l,l,mid,L,R);
    if(R>mid) ans+=getsum(tree[rt].r,mid+1,r,L,R);
    return ans;
}

// CC: moves city x from its current religion's tree to religion p's tree.
void change1(int x,int p){
    del(root[c[x]],1,n,id[x]);
    c[x]=p;
    update(root[c[x]],1,n,id[x],w[x]);
}

// CW: sets city x's rating to p inside its religion's tree.
void change2(int x,int p){
    del(root[c[x]],1,n,id[x]);
    w[x]=p;
    update(root[c[x]],1,n,id[x],w[x]);
}

// QS: sum over the path x..y, restricted to the tree of x's religion.
int getans1(int x,int y){
    int ans=0;
    int k=c[x];
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        ans+=getsum(root[k],1,n,id[top[x]],id[x]);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    ans+=getsum(root[k],1,n,id[x],id[y]);
    return ans;
}

// QM: maximum over the path x..y, restricted to the tree of x's religion.
int getans2(int x,int y){
    int ans=0;
    int k=c[x];
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        ans=max(ans,getmax(root[k],1,n,id[top[x]],id[x]));
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    ans=max(ans,getmax(root[k],1,n,id[x],id[y]));
    return ans;
}

int main(){
#ifdef LOCAL
    freopen("read.in","r",stdin);   // local debugging only
#endif
    int u,v;
    int m;
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++){
        scanf("%d%d",&w[i],&c[i]);
    }
    for(int i=1;i<n;i++){
        scanf("%d%d",&u,&v);
        myinsert(u,v);
        myinsert(v,u);
    }
    dfs1(1,1,1);
    dfs2(1,1,1);
    // Insert every city into its religion's tree at its DFS position.
    for(int i=1;i<=n;i++){
        update(root[c[mp[i]]],1,n,i,w[mp[i]]);
    }
    int x,p;
    while(m--){
        scanf("%s",ch);
        // The second letter uniquely identifies the command: CC, CW, QS, QM.
        if(ch[1]=='C'){
            scanf("%d%d",&x,&p);
            change1(x,p);
        }
        else if(ch[1]=='W'){
            scanf("%d%d",&x,&p);
            change2(x,p);
        }
        else if(ch[1]=='S'){
            scanf("%d%d",&x,&p);
            printf("%d\n",getans1(x,p));
        }
        else if(ch[1]=='M'){
            scanf("%d%d",&x,&p);
            printf("%d\n",getans2(x,p));
        }
    }
    return 0;
}