树链剖分(+线段树)(codevs4633)

// Adjacency-list arc: singly linked list of the neighbours of one vertex.
type node=^link;
  link=record
    des:longint;   // neighbour vertex at the other end of the arc
    next:node;     // next arc of the same vertex (nil ends the list)
     end;

// Segment-tree node, stored in the global array tr.
type seg=record
     // z..y: covered position range; lc/rc: child indices in tr (never set for
     // leaves, so they stay 0); toadd: pending increment not yet folded into sum;
     // sum: sum of the values in z..y
     z,y,lc,rc,toadd,sum:longint;
  end;

var
 // tot is reused: DFS position counter in dfs2, then tree-node counter in buildtree
 n,tot,i,t1,t2,q,a,b,c:longint;
 p:node;
 // per-vertex data (1-based): son = heavy child (0 = none), siz = subtree size,
 // dep = depth (root is 1; non-zero also means "visited" in dfs1), fa = parent,
 // num = position in DFS order, top = head of the vertex's heavy chain
 son,siz,dep,fa,num,top:array[1..100000] of longint;
 // segment-tree nodes; allocation starts at index 1, and slot 0 only absorbs
 // lazy pushes from leaves (whose lc/rc are 0)
 tr:array[0..250000] of seg;
 // nd[v]: first arc in v's adjacency list
 nd:array[1..100000] of node;

 // Larger of two longints (b when they are equal).
 function max(a,b:longint):longint;
   begin
     if b>=a then max:=b else max:=a;
   end;

 // Smaller of two longints (a when they are equal).
 function min(a,b:longint):longint;
   begin
     if b<a then min:=b else min:=a;
   end;

 // First DFS of the heavy-light decomposition.
 // For every vertex reached from po it records the parent (fa), depth (dep),
 // subtree size (siz) and heavy child (son = first child with the largest
 // subtree, 0 when po has no unvisited neighbour).
 // dep<>0 marks a vertex as visited, so the start vertex needs a non-zero dep
 // before the call (main sets dep[1]:=1).
 procedure dfs1(po:longint);
 var
  p:node;
  v,best:longint;
   begin
     siz[po]:=1;son[po]:=0;
     // Keep the heaviest child size in a local variable. This avoids reading
     // siz[0] (outside siz's 1..100000 range) while son[po] is still 0.
     best:=0;
     p:=nd[po];
     while p<>nil do
       begin
         v:=p^.des;
         if dep[v]=0 then
           begin
             dep[v]:=dep[po]+1;
             fa[v]:=po;
             dfs1(v);
             if siz[v]>best then
               begin
                 best:=siz[v];
                 son[po]:=v;
               end;
             siz[po]:=siz[po]+siz[v];
           end;
         p:=p^.next;
       end;
   end;

 // Second DFS: assigns DFS positions (num) from the global counter tot and the
 // chain head (top). The heavy child is visited first, so every heavy chain
 // occupies a contiguous run of positions.
 procedure dfs2(po,tp:longint);
 var
  e:node;
  w:longint;
   begin
     tot:=tot+1;
     num[po]:=tot;
     top[po]:=tp;
     // the heavy child continues the current chain
     if son[po]<>0 then dfs2(son[po],tp);
     e:=nd[po];
     while e<>nil do
       begin
         w:=e^.des;
         // each light child starts a new chain headed by itself
         if (w<>son[po]) and (w<>fa[po]) then dfs2(w,w);
         e:=e^.next;
       end;
   end;

 // Allocates segment-tree nodes for positions l..r in preorder, using the
 // global counter tot. All sums and pending increments start at 0.
 procedure buildtree(l,r:longint);
 var
  cur,mid:longint;
   begin
     inc(tot);
     cur:=tot;
     with tr[cur] do
       begin
         z:=l; y:=r;
         sum:=0; toadd:=0;
       end;
     if l<>r then
       begin
         mid:=(l+r) div 2;
         // the left child is the very next node to be allocated
         tr[cur].lc:=tot+1;
         buildtree(l,mid);
         tr[cur].rc:=tot+1;
         buildtree(mid+1,r);
       end;
   end;

  // Adds k to every position in l..r. The range must lie inside node po's range
  // [tr[po].z, tr[po].y]; the recursive calls clip it with min/max.
  // A node's toadd is an increment that has not yet been folded into its own
  // sum nor forwarded to its children.
  // NOTE(review): sum/toadd are longint, so large totals would overflow --
  // confirm against the problem's bounds.
  procedure add(po,l,r,k:longint);
  var
   mid:longint;
    begin
      // push this node's pending increment down first (for a leaf, lc/rc are 0,
      // so the push lands in the otherwise unused slot tr[0])
      if tr[po].toadd<>0 then
        begin
          tr[po].sum:=tr[po].sum+(tr[po].y-tr[po].z+1)*tr[po].toadd;
          tr[tr[po].lc].toadd:=tr[tr[po].lc].toadd+tr[po].toadd;
          tr[tr[po].rc].toadd:=tr[tr[po].rc].toadd+tr[po].toadd;
          tr[po].toadd:=0;
        end;

      mid:=(tr[po].z+tr[po].y) div 2;
      // this node's sum grows by k for each covered position, whatever the split
      tr[po].sum:=tr[po].sum+(r-l+1)*k;
      if (l=tr[po].z) and (r=tr[po].y) then
        begin
          // fully covered: defer the update to the children
          tr[tr[po].lc].toadd:=tr[tr[po].lc].toadd+k;
          tr[tr[po].rc].toadd:=tr[tr[po].rc].toadd+k;
          exit;
        end else
          begin
            if mid>=l then add(tr[po].lc,l,min(mid,r),k);
            if r>mid then add(tr[po].rc,max(mid+1,l),r,k);
          end;
    end;// segment-tree range add

  // Returns the sum over positions l..r. The range must lie inside node po's
  // range [tr[po].z, tr[po].y]; the recursive calls clip it with min/max.
  // Pending increments met on the way down are folded into sum and pushed to
  // the children before the node is read.
  // NOTE(review): the result is a longint, so very large path sums would
  // overflow -- confirm against the problem's bounds.
  function ans(po,l,r:longint):longint;
  var
   mid:longint;
    begin
      // push down this node's pending increment (for a leaf, lc/rc are 0,
      // so the push lands in the otherwise unused slot tr[0])
      if tr[po].toadd<>0 then
        begin
          tr[po].sum:=tr[po].sum+(tr[po].y-tr[po].z+1)*tr[po].toadd;
          tr[tr[po].lc].toadd:=tr[tr[po].lc].toadd+tr[po].toadd;
          tr[tr[po].rc].toadd:=tr[tr[po].rc].toadd+tr[po].toadd;
          tr[po].toadd:=0;
        end;

      mid:=(tr[po].z+tr[po].y) div 2;
      // exact match: the node's sum is now up to date
      if (l=tr[po].z) and (r=tr[po].y) then
        exit(tr[po].sum) else
          begin
            // inside this function, a bare "ans" is the result variable and
            // "ans(...)" is a recursive call
            ans:=0;
            if mid>=l then ans:=ans+ans(tr[po].lc,l,min(mid,r));
            if r>mid then ans:=ans+ans(tr[po].rc,max(mid+1,l),r);
          end;
    end;// segment-tree range sum

  // Adds 1 to every vertex on the path b..c.
  // It repeatedly handles the endpoint whose chain head is deeper, covering the
  // positions from its chain head down to itself. Once both endpoints share a
  // chain, one contiguous range remains.
  procedure plus(b,c:longint);
  var
   lo,hi:longint;
    begin
      while top[b]<>top[c] do
        if dep[top[c]]>dep[top[b]] then
          begin
            add(1,num[top[c]],num[c],1);
            c:=fa[top[c]];
          end
        else
          begin
            add(1,num[top[b]],num[b],1);
            b:=fa[top[b]];
          end;
      lo:=min(num[b],num[c]);
      hi:=max(num[b],num[c]);
      add(1,lo,hi,1);
    end;

  // Returns the sum of the values on the path b..c. It uses the same chain
  // climbing as plus, but reads ranges with ans instead of updating them.
  function query(b,c:longint):longint;
  var
   total:longint;
    begin
      total:=0;
      while top[b]<>top[c] do
        if dep[top[c]]>dep[top[b]] then
          begin
            total:=total+ans(1,num[top[c]],num[c]);
            c:=fa[top[c]];
          end
        else
          begin
            total:=total+ans(1,num[top[b]],num[b]);
            b:=fa[top[b]];
          end;
      // same chain now: the rest of the path is one contiguous position range
      total:=total+ans(1,min(num[b],num[c]),max(num[b],num[c]));
      query:=total;
    end;

  begin
    read(n);

    // store each undirected edge as two arcs, one in each endpoint's list
    for i:=1 to n-1 do
      begin
        read(t1,t2);
        new(p);
        p^.next:=nd[t1]; p^.des:=t2;
        nd[t1]:=p;
        new(p);
        p^.next:=nd[t2]; p^.des:=t1;
        nd[t2]:=p;
      end;

    // root at vertex 1: a non-zero dep marks a vertex as visited in dfs1
    dep[1]:=1;
    dfs1(1);
    dfs2(1,1);

    // tot served as the DFS position counter; restart it for tree-node allocation
    tot:=0;
    buildtree(1,n);

    // a=1: add 1 on the path b..c; a=2: print the path sum
    read(q);
    while q>0 do
      begin
        read(a,b,c);
        if a=1 then plus(b,c)
        else if a=2 then writeln(query(b,c));
        dec(q);
      end;
  end.

 ————————————————————————————————————————————————————————————————

 C++ (BZOJ 1036)

#include <cstdio>
#include <iostream>
#define LL long long
using namespace std;

  // Adjacency list: arc p leads to des[p]; next[p] is the following arc of the
  // same tail vertex; nd[v] is v's first arc (-1 = none, set in main).
  // bt: visited flag for dfs1; son: heavy child (-1 = leaf); maxi: heavy child's size.
  // NOTE(review): with `using namespace std`, the globals `next` and `size` may be
  // ambiguous with std::next / std::size on newer compilers -- consider renaming.
  int next[60001],des[60001],nd[30001],bt[30001],son[30001],maxi[30001];
  // fa: parent; dep: depth (root is 1); size: subtree size; id: DFS position;
  // top: head of the vertex's heavy chain; a: weights (read per vertex in main,
  // but written per position right before edi()); revid: position -> vertex.
  int fa[30001],dep[30001],size[30001],id[30001],top[30001],a[30001],revid[30001];
  // cnt is reused: arc counter, then DFS position counter, then tree-node counter.
  int cnt,n,q;
  
  // Segment-tree node over positions l..r: child indices lc/rc, plus the
  // maximum and the sum of the range.
  struct node{
      int l,r,lc,rc,maxi,sum;
  }tr[60001];

  // Exchanges the two referenced ints.
  void swp(int &x,int &y){
      int old=y;
      y=x;
      x=old;
  }

  // Stores the undirected edge x-y as two arcs, each pushed onto the front of
  // its tail vertex's list.
  void addedge(int x,int y){
      ++cnt;                 // arc x -> y
      des[cnt]=y;
      next[cnt]=nd[x];
      nd[x]=cnt;
      ++cnt;                 // arc y -> x
      des[cnt]=x;
      next[cnt]=nd[y];
      nd[y]=cnt;
  }
  
  // First DFS: records parent, depth and subtree size, and picks the heavy child
  // (first child with the largest subtree; -1 when there is none).
  void dfs1(int po){
      bt[po]=1;
      size[po]=1;
      son[po]=maxi[po]=-1;
      for (int e=nd[po];e!=-1;e=next[e]){
          int v=des[e];
          if (bt[v]) continue;          // already visited (the parent)
          fa[v]=po;
          dep[v]=dep[po]+1;
          dfs1(v);
          size[po]+=size[v];
          if (size[v]>maxi[po]){
              maxi[po]=size[v];
              son[po]=v;
          }
      }
  }
  
  // Second DFS: numbers vertices with the heavy child first, so each heavy
  // chain (head tp) gets contiguous positions.
  void dfs2(int po,int tp){
      top[po]=tp;
      id[po]=++cnt;
      if (son[po]!=-1){
          dfs2(son[po],tp);
          for (int e=nd[po];e!=-1;e=next[e]){
              int v=des[e];
              if (v==fa[po]||v==son[po]) continue;
              dfs2(v,v);                // a light child heads its own chain
          }
      }
  }
  
  void update(int po){
      tr[po].sum=tr[tr[po].lc].sum+tr[tr[po].rc].sum;
      tr[po].maxi=max(tr[tr[po].lc].maxi,tr[tr[po].rc].maxi);
  }
  
  // Builds the segment tree over positions l..r in preorder. The leaf for
  // position l holds the weight of vertex revid[l].
  void build(int l,int r){
      int me=++cnt;
      tr[me].l=l;
      tr[me].r=r;
      if (l==r){
          tr[me].maxi=a[revid[l]];
          tr[me].sum=tr[me].maxi;
          return;
      }
      int mid=(l+r)>>1;
      tr[me].lc=cnt+1; build(l,mid);
      tr[me].rc=cnt+1; build(mid+1,r);
      update(me);
  }
  
  void edi(int po,int targ){
      if (tr[po].l==tr[po].r) {tr[po].sum=tr[po].maxi=a[targ];return;}
      
      int mid=(tr[po].l+tr[po].r>>1);
    if (targ<=mid) edi(tr[po].lc,targ);else edi(tr[po].rc,targ);
    update(po);
  }
  
  int getmax(int po,int l,int r){
      if (l==tr[po].l&&r==tr[po].r) return(tr[po].maxi);
      int mid=(tr[po].l+tr[po].r)>>1;
      
      int ret=-1e9;
      if (l<=mid) ret=max(ret,getmax(tr[po].lc,l,min(mid,r)));
      if (r>mid)  ret=max(ret,getmax(tr[po].rc,max(mid+1,l),r));
      return(ret);
  }
  
  // Prints the largest vertex weight on the path x..y.
  void QMAX(int x,int y){
      int best=-1e9;
      // lift the endpoint whose chain head is deeper, one whole chain at a time
      for (;top[x]!=top[y];x=fa[top[x]]){
          if (dep[top[x]]<dep[top[y]]) swp(x,y);
          best=max(best,getmax(1,id[top[x]],id[x]));
      }
      // same chain: make y the shallower end and cover id[y]..id[x]
      if (dep[x]<dep[y]) swp(x,y);
      best=max(best,getmax(1,id[y],id[x]));
      printf("%d\n",best);
  }
  
  int getsum(int po,int l,int r){
      if (l==tr[po].l&&r==tr[po].r) return(tr[po].sum);
      int mid=(tr[po].l+tr[po].r)>>1;
      
      int ret=0;
      if (l<=mid) ret+=getsum(tr[po].lc,l,min(mid,r));
      if (r>mid) ret+=getsum(tr[po].rc,max(mid+1,l),r);
      return(ret);
  }
  
  // Prints the sum of the vertex weights on the path x..y.
  void QSUM(int x,int y){
      int total=0;
      // lift the endpoint whose chain head is deeper, one whole chain at a time
      for (;top[x]!=top[y];x=fa[top[x]]){
          if (dep[top[x]]<dep[top[y]]) swp(x,y);
          total+=getsum(1,id[top[x]],id[x]);
      }
      // same chain: make y the shallower end and cover id[y]..id[x]
      if (dep[x]<dep[y]) swp(x,y);
      total+=getsum(1,id[y],id[x]);
      printf("%d\n",total);
  }

  // Reads the tree and vertex weights, builds the decomposition and the
  // segment tree, then answers QMAX / QSUM path queries and CHANGE point updates.
  int main(){
      scanf("%d",&n);

      for (int i=1;i<=n;i++) nd[i]=-1;   // -1 terminates each adjacency list
      for (int i=1;i<n;i++){
          int x,y;
          scanf("%d%d",&x,&y);
          addedge(x,y);
      }

      dep[1]=1;
      dfs1(1);

      cnt=0;                              // arc counter -> DFS position counter
      dfs2(1,1);
      for (int i=1;i<=n;i++) revid[id[i]]=i;

      for (int i=1;i<=n;i++) scanf("%d",&a[i]);
      cnt=0;                              // DFS position counter -> tree-node counter
      build(1,n);

      scanf("%d",&q);
      char st[11];
      for (int i=1;i<=q;i++){
          // %s needs a char* (st decays to one; &st is char(*)[11]); the width
          // keeps the word inside the buffer
          if (scanf("%10s",st)!=1) break;
          int x,y;
          if (scanf("%d%d",&x,&y)!=2) break;

          // the commands differ in their second letter: qMax, qSum, cHange
          if (st[1]=='M') QMAX(x,y);
          else if (st[1]=='S') QSUM(x,y);
          else if (st[1]=='H'){
              // edi() reads the new value from a[] at the position it updates,
              // so the value is stored at a[id[x]] (a[] is read per vertex only by build())
              a[id[x]]=y;
              edi(1,id[x]);
          }
      }
      return 0;
  }

——————————————————————————————————

 树链剖分可对每条链单独建立线段树以减小常数：每条重链只建一棵大小与链长相当的线段树，查询时只在这棵较小（树高更低）的树上进行，而不必在覆盖全部 n 个点的大树上操作。

#include <cstdio>
#include <iostream>
#define LL long long
using namespace std;

  // Adjacency list stored arc by arc. A tree with n-1 edges needs 2(n-1) arcs,
  // so next/des/len are sized for about 2n entries, not n.
  // nd[v]: first arc of v (-1 = none); b: visited flag for dfs1; son: heavy child (0 = none).
  // NOTE(review): with `using namespace std`, `next` and `size` may be ambiguous
  // with std::next / std::size on newer compilers -- consider renaming.
  int next[400005],des[400005],nd[200001],cnt,size[200001],b[200001],fa[200001],dep[200001],son[200001];
  // id/rev: vertex <-> DFS position; top: chain head; fr/to: endpoints of input edge i;
  // root[h]: tree-node index of the segment tree for the chain headed by h;
  // maxid[h]: last position of that chain.
  int id[200001],rev[200001],top[200001],n,q,fr[200001],to[200001],root[200001],maxid[200001];
  LL len[400005];     // weight of each arc (both arcs of an edge share it)
  LL num[200001];     // num[v]: weight of the edge from v to its parent (0 for the root)

  // One segment tree per heavy chain, all packed into tr. A chain of k vertices
  // uses 2k-1 nodes, so the total can reach 2n-1 and tr needs about 2n slots.
  struct treenode{
      int l,r,lc,rc;   // covered position range and child indices
      LL num;          // product over the range, or -1 once it exceeds 1e18
  }tr[400005];

  // Stores the undirected weighted edge x-y as two arcs that share weight num.
  void addedge(int x,int y,LL num){
      ++cnt;                               // arc x -> y
      des[cnt]=y; len[cnt]=num;
      next[cnt]=nd[x]; nd[x]=cnt;
      ++cnt;                               // arc y -> x
      des[cnt]=x; len[cnt]=num;
      next[cnt]=nd[y]; nd[y]=cnt;
  }
  
  // First DFS: parent, depth, subtree size and heavy child. It also stores, per
  // child, the weight of the edge leading up to its parent in num[].
  void dfs1(int po){
      b[po]=1;
      size[po]=1;
      int heaviest=-1;
      for (int e=nd[po];e!=-1;e=next[e]){
          int v=des[e];
          if (b[v]!=0) continue;           // already visited (the parent)
          fa[v]=po;
          dep[v]=dep[po]+1;
          num[v]=len[e];
          dfs1(v);
          if (size[v]>heaviest){
              heaviest=size[v];
              son[po]=v;
          }
          size[po]+=size[v];
      }
  }
  
  // Second DFS: DFS positions (id, inverse rev) with the heavy child first, so
  // each chain headed by tp occupies contiguous positions.
  void dfs2(int po,int tp){
      top[po]=tp;
      id[po]=++cnt;
      rev[id[po]]=po;
      if (son[po]) dfs2(son[po],tp);
      for (int e=nd[po];e!=-1;e=next[e]){
          int v=des[e];
          if (v==fa[po]||v==son[po]) continue;
          dfs2(v,v);                       // a light child heads its own chain
      }
  }
  
  // a = b*c, saturating to -1 when either factor is already -1 or when the
  // product would exceed 1e18 (the check is done in floating point).
  void update(LL &a,LL b,LL c){
      if (b==-1||c==-1||1e18/b<c) a=-1;
      else a=b*c;
  }
  
  // Builds a product segment tree over positions l..r in preorder. The leaf for
  // position l holds the parent-edge weight of vertex rev[l].
  void build(int l,int r){
      int me=++cnt;
      tr[me].l=l;
      tr[me].r=r;
      if (l==r){
          tr[me].num=num[rev[l]];
          return;
      }
      int mid=(l+r)>>1;
      tr[me].lc=cnt+1; build(l,mid);
      tr[me].rc=cnt+1; build(mid+1,r);
      update(tr[me].num,tr[tr[me].lc].num,tr[tr[me].rc].num);
  }
  
  void edi(int po,int tar,LL num){
      if (tr[po].l==tr[po].r) {tr[po].num=num;return;}
      
      int mid=(tr[po].l+tr[po].r)>>1;
      if (tar<=mid) edi(tr[po].lc,tar,num);else
                    edi(tr[po].rc,tar,num);
    update(tr[po].num,tr[tr[po].lc].num,tr[tr[po].rc].num);
  }
  
  // Saturating product over positions l..r, which must lie inside node po's
  // range (-1 = exceeds 1e18).
  LL getnum(int po,int l,int r){
      const treenode &cur=tr[po];
      if (cur.l==l&&cur.r==r) return cur.num;
      int mid=(cur.l+cur.r)>>1;
      LL prod=1;
      if (l<=mid) update(prod,prod,getnum(cur.lc,l,min(mid,r)));
      if (mid<r)  update(prod,prod,getnum(cur.rc,max(mid+1,l),r));
      return prod;
  }
  
  // Saturating product of the edge weights on the path x..y (1 when x==y).
  // Each edge weight sits at its child vertex, so on the final shared chain the
  // upper endpoint's own entry is skipped by starting at son[y].
  LL query(int x,int y){
      LL acc=1;
      for (;top[x]!=top[y];x=fa[top[x]]){
          if (dep[top[x]]<dep[top[y]]){ int tmp=x; x=y; y=tmp; }
          update(acc,acc,getnum(root[top[x]],id[top[x]],id[x]));
      }
      if (dep[x]<dep[y]){ int tmp=x; x=y; y=tmp; }
      if (x!=y) update(acc,acc,getnum(root[top[x]],id[son[y]],id[x]));
      return acc;
  }

  int main(){
      scanf("%d%d",&n,&q);
      for (int i=1;i<=n;i++) nd[i]=-1;
      for (int i=1;i<n;i++){
        int t1,t2,t3;
      scanf("%d%d%lld",&fr[i],&to[i],&t3);
      addedge(fr[i],to[i],t3);    
    }
    
    dep[1]=1;
    dfs1(1);
    cnt=0;
    dfs2(1,1);
    cnt=0;
    for (int i=1;i<=n;i++) maxid[top[i]]=max(maxid[top[i]],id[i]);
    for (int i=1;i<=n;i++) if (i==top[i]){
      root[i]=cnt+1;build(id[i],maxid[i]);
    }
    
    for (int i=1;i<=q;i++){
      int typ;
      scanf("%d",&typ);
      
      if (typ==1){
          int x,y;LL v;
          scanf("%d%d%lld",&x,&y,&v);
          LL t=query(x,y);
          if (t==-1) printf("0\n");else printf("%lld\n",v/t);
      }
      
      if (typ==2){
          int li;LL v;
          scanf("%d%lld",&li,&v);
          if (fa[fr[li]]==to[li]){
            int t=fr[li];to[li]=fr[li];fr[li]=t;    
        }
        edi(root[top[to[li]]],id[to[li]],v);
      }
    }
  }

 

posted @ 2016-03-16 13:39  z1j1n1  阅读(291)  评论(0)  编辑  收藏  举报