树链剖分(+线段树)(codevs4633)
{ Heavy-light decomposition + segment tree with lazy propagation (codevs 4633).
  Input: n, the n-1 tree edges, then q operations "a b c":
    a = 1: add 1 to the weight of every vertex on the path b..c (inclusive);
    a = 2: print the sum of the weights on the path b..c.
  All weights start at 0. A path sum can reach about n*q, which does not fit
  in a longint, so segment sums, lazy tags and query results are int64. }
type
  node = ^link;
  link = record
    des: longint;   { neighbouring vertex }
    next: node;     { next entry of the adjacency list }
  end;
  seg = record
    z, y: longint;    { interval [z, y] covered by this segment-tree node }
    lc, rc: longint;  { indices of the children (not used for leaves) }
    toadd: int64;     { pending per-position addition, not yet applied to sum }
    sum: int64;       { interval sum, excluding this node's own pending toadd }
  end;

var
  n, tot, i, t1, t2, q, a, b, c: longint;
  p: node;
  { index 0 is the "no heavy son" sentinel: dfs1 reads siz[son[po]] while
    son[po] = 0, so the arrays must include index 0 }
  son, siz, dep, fa, num, top: array[0..100000] of longint;
  tr: array[0..250000] of seg;
  nd: array[1..100000] of node;

function max(a, b: longint): longint;
begin
  if a > b then exit(a) else exit(b);
end;

function min(a, b: longint): longint;
begin
  if a > b then exit(b) else exit(a);
end;

{ First DFS from po: fills dep, fa and subtree sizes siz, and stores in
  son[po] the child with the largest subtree (0 for a leaf).
  dep = 0 marks a vertex as not yet visited. }
procedure dfs1(po: longint);
var
  p: node;
begin
  siz[po] := 1;
  son[po] := 0;
  p := nd[po];
  while p <> nil do
  begin
    if dep[p^.des] = 0 then
    begin
      dep[p^.des] := dep[po] + 1;
      fa[p^.des] := po;
      dfs1(p^.des);
      if siz[p^.des] > siz[son[po]] then
        son[po] := p^.des;
      siz[po] := siz[po] + siz[p^.des];
    end;
    p := p^.next;
  end;
end;

{ Second DFS: links heavy edges into chains. num[po] is po's position in the
  segment tree; the heavy son is visited first so every chain occupies a
  contiguous range of positions. top[po] is the head of po's chain. }
procedure dfs2(po, tp: longint);
var
  p: node;
begin
  inc(tot);
  num[po] := tot;
  top[po] := tp;
  if son[po] <> 0 then
    dfs2(son[po], tp);
  p := nd[po];
  while p <> nil do
  begin
    if (p^.des <> son[po]) and (p^.des <> fa[po]) then
      dfs2(p^.des, p^.des);
    p := p^.next;
  end;
end;

{ Builds the segment tree over positions l..r; nodes are numbered in
  creation order through the global counter tot. }
procedure buildtree(l, r: longint);
var
  t: longint;
begin
  inc(tot);
  tr[tot].sum := 0;
  tr[tot].toadd := 0;
  tr[tot].z := l;
  tr[tot].y := r;
  t := tot;
  if l = r then
    exit;
  tr[t].lc := tot + 1;
  buildtree(l, (l + r) div 2);
  tr[t].rc := tot + 1;
  buildtree((l + r) div 2 + 1, r);
end;

{ Applies node po's pending addition to its own sum and hands it down to
  its children (leaves have no children to receive it). }
procedure pushdown(po: longint);
begin
  if tr[po].toadd <> 0 then
  begin
    tr[po].sum := tr[po].sum + int64(tr[po].y - tr[po].z + 1) * tr[po].toadd;
    if tr[po].z < tr[po].y then
    begin
      tr[tr[po].lc].toadd := tr[tr[po].lc].toadd + tr[po].toadd;
      tr[tr[po].rc].toadd := tr[tr[po].rc].toadd + tr[po].toadd;
    end;
    tr[po].toadd := 0;
  end;
end;

{ Adds k to every position in [l, r], which must lie inside node po's interval. }
procedure add(po, l, r, k: longint);
var
  mid: longint;
begin
  pushdown(po);
  { po's sum is updated eagerly; fully covered subtrees get a lazy tag }
  tr[po].sum := tr[po].sum + int64(r - l + 1) * k;
  if (l = tr[po].z) and (r = tr[po].y) then
  begin
    if tr[po].z < tr[po].y then
    begin
      tr[tr[po].lc].toadd := tr[tr[po].lc].toadd + k;
      tr[tr[po].rc].toadd := tr[tr[po].rc].toadd + k;
    end;
    exit;
  end;
  mid := (tr[po].z + tr[po].y) div 2;
  if mid >= l then
    add(tr[po].lc, l, min(mid, r), k);
  if r > mid then
    add(tr[po].rc, max(mid + 1, l), r, k);
end;

{ Returns the sum of positions [l, r], which must lie inside node po's interval. }
function ans(po, l, r: longint): int64;
var
  mid: longint;
  res: int64;
begin
  pushdown(po);
  if (l = tr[po].z) and (r = tr[po].y) then
    exit(tr[po].sum);
  mid := (tr[po].z + tr[po].y) div 2;
  res := 0;
  if mid >= l then
    res := res + ans(tr[po].lc, l, min(mid, r));
  if r > mid then
    res := res + ans(tr[po].rc, max(mid + 1, l), r);
  ans := res;
end;

{ Adds 1 to every vertex on the path b..c: repeatedly lifts the endpoint whose
  chain head is deeper, updating that chain segment, until both share a chain. }
procedure plus(b, c: longint);
begin
  while top[b] <> top[c] do
  begin
    if dep[top[b]] < dep[top[c]] then
    begin
      add(1, num[top[c]], num[c], 1);
      c := fa[top[c]];
    end
    else
    begin
      add(1, num[top[b]], num[b], 1);
      b := fa[top[b]];
    end;
  end;
  if num[b] < num[c] then
    add(1, num[b], num[c], 1)
  else
    add(1, num[c], num[b], 1);
end;

{ Returns the sum of the weights on the path b..c, walking chains as in plus. }
function query(b, c: longint): int64;
var
  res: int64;
begin
  res := 0;
  while top[b] <> top[c] do
  begin
    if dep[top[b]] < dep[top[c]] then
    begin
      res := res + ans(1, num[top[c]], num[c]);
      c := fa[top[c]];
    end
    else
    begin
      res := res + ans(1, num[top[b]], num[b]);
      b := fa[top[b]];
    end;
  end;
  if num[b] < num[c] then
    res := res + ans(1, num[b], num[c])
  else
    res := res + ans(1, num[c], num[b]);
  query := res;
end;

begin
  read(n);
  { undirected tree: store every edge in both adjacency lists }
  for i := 1 to n - 1 do
  begin
    read(t1, t2);
    new(p); p^.des := t2; p^.next := nd[t1]; nd[t1] := p;
    new(p); p^.des := t1; p^.next := nd[t2]; nd[t2] := p;
  end;
  dep[1] := 1;
  dfs1(1);
  dfs2(1, 1);
  tot := 0;  { tot is reused as the segment-tree node counter }
  buildtree(1, n);
  read(q);
  for i := 1 to q do
  begin
    read(a, b, c);
    if a = 1 then plus(b, c);
    if a = 2 then writeln(query(b, c));
  end;
end.
————————————————————————————————————————————————————————————————
c++(BZOJ1036)
#include <cstdio>
#include <iostream>
#include <algorithm>
#define LL long long
using namespace std;

// Heavy-light decomposition + segment tree (BZOJ 1036).
// Input: n, the n-1 tree edges, the n vertex weights, then q commands:
//   "QMAX u v"   print the maximum vertex weight on the path u..v
//   "QSUM u v"   print the sum of the vertex weights on the path u..v
//   "CHANGE u t" set the weight of vertex u to t
// Arrays are sized for n <= 30000.
// The adjacency-list link array is called nxt and the subtree sizes sz:
// with `using namespace std`, global names `next` / `size` can be ambiguous
// with std::next / std::size and then fail to compile.
int nxt[60001], des[60001], nd[30001], bt[30001], son[30001], maxi[30001];
int fa[30001], dep[30001], sz[30001], id[30001], top[30001], a[30001], revid[30001];
int cnt, n, q;
struct node {
    int l, r, lc, rc, maxi, sum;
} tr[60001];

void swp(int &x, int &y) {
    int t = x; x = y; y = t;
}

// Stores the undirected edge x-y as two directed entries.
void addedge(int x, int y) {
    nxt[++cnt] = nd[x]; des[cnt] = y; nd[x] = cnt;
    nxt[++cnt] = nd[y]; des[cnt] = x; nd[y] = cnt;
}

// First DFS: fills fa, dep, sz and the heavy son son[po] (-1 for a leaf).
void dfs1(int po) {
    bt[po] = 1;
    son[po] = -1; maxi[po] = -1;
    sz[po] = 1;
    for (int p = nd[po]; p != -1; p = nxt[p])
        if (bt[des[p]] == 0) {
            fa[des[p]] = po; dep[des[p]] = dep[po] + 1;
            dfs1(des[p]);
            sz[po] += sz[des[p]];
            if (sz[des[p]] > maxi[po]) {
                maxi[po] = sz[des[p]];
                son[po] = des[p];
            }
        }
}

// Second DFS: assigns positions id[] so that each heavy chain is contiguous;
// top[po] is the head of po's chain.
void dfs2(int po, int tp) {
    id[po] = ++cnt; top[po] = tp;
    if (son[po] == -1) return;
    dfs2(son[po], tp);
    for (int p = nd[po]; p != -1; p = nxt[p])
        if (des[p] != fa[po] && des[p] != son[po]) dfs2(des[p], des[p]);
}

// Recomputes node po's sum and maximum from its children.
void update(int po) {
    tr[po].sum = tr[tr[po].lc].sum + tr[tr[po].rc].sum;
    tr[po].maxi = max(tr[tr[po].lc].maxi, tr[tr[po].rc].maxi);
}

// Builds the tree over positions l..r; leaf at position l holds a[revid[l]].
void build(int l, int r) {
    tr[++cnt].l = l; tr[cnt].r = r;
    if (l == r) { tr[cnt].sum = tr[cnt].maxi = a[revid[l]]; return; }
    int t = cnt, mid = (l + r) >> 1;
    tr[t].lc = cnt + 1;
    build(l, mid);
    tr[t].rc = cnt + 1;
    build(mid + 1, r);
    update(t);
}

// Point assignment: sets the leaf at position targ to val.
void edi(int po, int targ, int val) {
    if (tr[po].l == tr[po].r) { tr[po].sum = tr[po].maxi = val; return; }
    int mid = (tr[po].l + tr[po].r) >> 1;
    if (targ <= mid) edi(tr[po].lc, targ, val); else edi(tr[po].rc, targ, val);
    update(po);
}

// Maximum over positions [l, r] (must lie inside node po's interval).
int getmax(int po, int l, int r) {
    if (l == tr[po].l && r == tr[po].r) return tr[po].maxi;
    int mid = (tr[po].l + tr[po].r) >> 1;
    int ret = -1e9;
    if (l <= mid) ret = max(ret, getmax(tr[po].lc, l, min(mid, r)));
    if (r > mid) ret = max(ret, getmax(tr[po].rc, max(mid + 1, l), r));
    return ret;
}

// Prints the maximum weight on the path x..y: lift the endpoint whose chain
// head is deeper until both are on one chain, then query that chain segment.
void QMAX(int x, int y) {
    int ans = -1e9;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) swp(x, y);
        ans = max(ans, getmax(1, id[top[x]], id[x]));
        x = fa[top[x]];
    }
    if (dep[x] < dep[y]) swp(x, y);
    ans = max(ans, getmax(1, id[y], id[x]));
    printf("%d\n", ans);
}

// Sum over positions [l, r] (must lie inside node po's interval).
int getsum(int po, int l, int r) {
    if (l == tr[po].l && r == tr[po].r) return tr[po].sum;
    int mid = (tr[po].l + tr[po].r) >> 1;
    int ret = 0;
    if (l <= mid) ret += getsum(tr[po].lc, l, min(mid, r));
    if (r > mid) ret += getsum(tr[po].rc, max(mid + 1, l), r);
    return ret;
}

// Prints the sum of the weights on the path x..y (same walk as QMAX).
void QSUM(int x, int y) {
    int ans = 0;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) swp(x, y);
        ans += getsum(1, id[top[x]], id[x]);
        x = fa[top[x]];
    }
    if (dep[x] < dep[y]) swp(x, y);
    ans += getsum(1, id[y], id[x]);
    printf("%d\n", ans);
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) nd[i] = -1;
    for (int i = 1; i < n; i++) {
        int x, y;
        scanf("%d%d", &x, &y);
        addedge(x, y);
    }
    dep[1] = 1;
    dfs1(1);
    cnt = 0;  // cnt is reused: edge counter above, DFS position counter here
    dfs2(1, 1);
    for (int i = 1; i <= n; i++) revid[id[i]] = i;
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
    cnt = 0;  // ...and segment-tree node counter for build
    build(1, n);
    scanf("%d", &q);
    char st[11];
    for (int i = 1; i <= q; i++) {
        scanf("%10s", st);
        int x, y;
        scanf("%d%d", &x, &y);
        // the second letter tells QMAX / QSUM / CHANGE apart
        if (st[1] == 'M') QMAX(x, y);
        if (st[1] == 'S') QSUM(x, y);
        if (st[1] == 'H') { a[x] = y; edi(1, id[x], y); }
    }
    return 0;
}
——————————————————————————————————
树链剖分时，可以对每条重链单独建立一棵线段树，以减小常数。
#include <cstdio>
#include <iostream>
#include <algorithm>
#define LL long long
using namespace std;

// Heavy-light decomposition with a separate segment tree per heavy chain
// (smaller trees -> smaller constant factor). Each edge weight is stored on
// the child endpoint of that edge (num[child]).
// Commands:
//   "1 x y v"  print v / (product of the edge weights on the path x..y),
//              integer division; a product above LIMIT is saturated to -1
//              and 0 is printed (assumes v <= LIMIT -- TODO confirm bound)
//   "2 li v"   set the weight of the li-th input edge to v
const int MAXN = 200005;                  // vertices (+ slack)
const int MAXE = 2 * MAXN;                // each undirected edge is stored twice
const LL LIMIT = 1000000000000000000LL;   // 1e18 saturation bound

// nxt / sz instead of next / size: with `using namespace std` those names
// can be ambiguous with std::next / std::size.
int nxt[MAXE], des[MAXE], nd[MAXN], cnt, sz[MAXN], vis[MAXN], fa[MAXN], dep[MAXN], son[MAXN];
int id[MAXN], rev[MAXN], top[MAXN], n, q, fr[MAXN], to[MAXN], root[MAXN], maxid[MAXN];
LL len[MAXE];
LL num[MAXN];
// a chain of k vertices needs 2k-1 tree nodes, so all chains need < 2*n nodes
struct treenode {
    int l, r, lc, rc;
    LL num;
} tr[2 * MAXN];

// Stores the undirected edge x-y with weight w as two directed entries.
void addedge(int x, int y, LL w) {
    nxt[++cnt] = nd[x]; des[cnt] = y; len[cnt] = w; nd[x] = cnt;
    nxt[++cnt] = nd[y]; des[cnt] = x; len[cnt] = w; nd[y] = cnt;
}

// First DFS: fills fa, dep, sz, the heavy son son[po] (0 for a leaf), and
// moves each edge weight onto its child endpoint num[child].
void dfs1(int po) {
    sz[po] = 1; vis[po] = 1;
    int best = -1;
    for (int p = nd[po]; p != -1; p = nxt[p]) {
        int v = des[p];
        if (vis[v]) continue;
        num[v] = len[p]; fa[v] = po;
        dep[v] = dep[po] + 1;
        dfs1(v);
        if (sz[v] > best) { best = sz[v]; son[po] = v; }
        sz[po] += sz[v];
    }
}

// Second DFS: positions id[] with every heavy chain contiguous; rev[] is the
// inverse map; top[po] is the head of po's chain.
void dfs2(int po, int tp) {
    id[po] = ++cnt; rev[cnt] = po; top[po] = tp;
    if (son[po]) dfs2(son[po], tp);
    for (int p = nd[po]; p != -1; p = nxt[p])
        if (des[p] != fa[po] && des[p] != son[po]) dfs2(des[p], des[p]);
}

// a = b * c with saturation: -1 means "greater than LIMIT" and is absorbing.
// Presumably real edge weights are >= 1; the root's entry num[1] is never set
// and stays 0, so a zero factor yields 0 (and must not reach LIMIT / b).
void update(LL &a, LL b, LL c) {
    if (b == -1 || c == -1) { a = -1; return; }
    if (b == 0 || c == 0) { a = 0; return; }
    // exact integer test for b * c > LIMIT (no double rounding, no overflow)
    if (c > LIMIT / b) { a = -1; return; }
    a = b * c;
}

// Builds a tree over positions l..r; leaf at position l holds num[rev[l]].
void build(int l, int r) {
    tr[++cnt].l = l; tr[cnt].r = r;
    if (l == r) { tr[cnt].num = num[rev[l]]; return; }
    int mid = (l + r) >> 1, t = cnt;
    tr[t].lc = cnt + 1;
    build(l, mid);
    tr[t].rc = cnt + 1;
    build(mid + 1, r);
    update(tr[t].num, tr[tr[t].lc].num, tr[tr[t].rc].num);
}

// Point assignment: sets the leaf at position tar to w.
void edi(int po, int tar, LL w) {
    if (tr[po].l == tr[po].r) { tr[po].num = w; return; }
    int mid = (tr[po].l + tr[po].r) >> 1;
    if (tar <= mid) edi(tr[po].lc, tar, w); else edi(tr[po].rc, tar, w);
    update(tr[po].num, tr[tr[po].lc].num, tr[tr[po].rc].num);
}

// Saturated product over positions [l, r] (inside node po's interval).
LL getnum(int po, int l, int r) {
    if (tr[po].l == l && tr[po].r == r) return tr[po].num;
    LL ret = 1;
    int mid = (tr[po].l + tr[po].r) >> 1;
    if (l <= mid) update(ret, ret, getnum(tr[po].lc, l, min(mid, r)));
    if (r > mid) update(ret, ret, getnum(tr[po].rc, max(mid + 1, l), r));
    return ret;
}

// Saturated product of the edge weights on the path x..y.
LL query(int x, int y) {
    LL ret = 1;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) { int t = x; x = y; y = t; }
        LL t = getnum(root[top[x]], id[top[x]], id[x]);
        update(ret, ret, t);
        x = fa[top[x]];
    }
    if (dep[x] < dep[y]) { int t = x; x = y; y = t; }
    if (x == y) return ret;
    // y is the shallower vertex on the shared chain; its own entry is the edge
    // above it (not on the path), so the range starts at its heavy son.
    LL t = getnum(root[top[x]], id[son[y]], id[x]);
    update(ret, ret, t);
    return ret;
}

int main() {
    scanf("%d%d", &n, &q);
    for (int i = 1; i <= n; i++) nd[i] = -1;
    for (int i = 1; i < n; i++) {
        LL w;  // weights need 64 bits: they are read with %lld
        scanf("%d%d%lld", &fr[i], &to[i], &w);
        addedge(fr[i], to[i], w);
    }
    dep[1] = 1;
    dfs1(1);
    cnt = 0;  // cnt reused as DFS position counter
    dfs2(1, 1);
    cnt = 0;  // ...and as segment-tree node counter
    // the chain headed by i occupies positions id[i]..maxid[i]
    for (int i = 1; i <= n; i++) maxid[top[i]] = max(maxid[top[i]], id[i]);
    for (int i = 1; i <= n; i++)
        if (i == top[i]) { root[i] = cnt + 1; build(id[i], maxid[i]); }
    for (int i = 1; i <= q; i++) {
        int typ;
        scanf("%d", &typ);
        if (typ == 1) {
            int x, y; LL v;
            scanf("%d%d%lld", &x, &y, &v);
            LL t = query(x, y);
            if (t == -1) printf("0\n"); else printf("%lld\n", v / t);
        }
        if (typ == 2) {
            int li; LL v;
            scanf("%d%lld", &li, &v);
            // orient the edge so that to[li] is the child endpoint
            if (fa[fr[li]] == to[li]) { int t = fr[li]; fr[li] = to[li]; to[li] = t; }
            edi(root[top[to[li]]], id[to[li]], v);
        }
    }
    return 0;
}