【BZOJ1036】树的统计Count(树链剖分,LCT)
题意:一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成一些操作:
I. CHANGE u t : 把结点u的权值改为t
II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值
III. QSUM u v: 询问从点u到点v的路径上的节点的权值和
注意:从点u到点v的路径上的节点包括u和v本身
1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。
思路:树链剖分,单点修改,区间查询和与最大值。
{ BZOJ1036 Count -- heavy-light decomposition + segment tree.
  Vertices are laid out in DFS order (tid[u] = position of u, id[pos] = vertex)
  with the heavy child visited first, so every heavy chain occupies a
  contiguous range of positions. The segment tree stores the range maximum
  in t1 and the range sum in t2.
  Sums fit in longint: at most 30000 * 30000 = 9e8 < 2^31. }
const
  MAXN = 200000;

var
  t1, t2, next, vet, head, dep, flag, tid, fa, top, size, id, son, a
    : array[1..MAXN] of longint;
  n, i, x, y, tot, dfn, q, p: longint;
  code: integer;
  line, cmd: string;

{ Append the directed edge a -> b to the adjacency list. }
procedure add(a, b: longint);
begin
  inc(tot);
  next[tot] := head[a];
  vet[tot] := b;
  head[a] := tot;
end;

function max(x, y: longint): longint;
begin
  if x > y then exit(x);
  exit(y);
end;

{ First pass: record parent, depth and subtree size of u, and pick its
  heavy child (the child with the largest subtree). }
procedure dfs1(u, fath, depth: longint);
var
  e, v, maxsize: longint;
begin
  flag[u] := 1;
  dep[u] := depth;
  fa[u] := fath;
  size[u] := 1;
  maxsize := 0;
  son[u] := 0;
  e := head[u];
  while e <> 0 do
  begin
    v := vet[e];
    if flag[v] = 0 then
    begin
      dfs1(v, u, depth + 1);
      size[u] := size[u] + size[v];
      if size[v] > maxsize then
      begin
        maxsize := size[v];
        son[u] := v;
      end;
    end;
    e := next[e];
  end;
end;

{ Second pass: assign DFS positions, heavy child first, and the chain
  top of every vertex (ance). }
procedure dfs2(u, ance: longint);
var
  e, v: longint;
begin
  flag[u] := 1;
  inc(dfn);
  tid[u] := dfn;
  id[dfn] := u;
  top[u] := ance;
  if son[u] > 0 then dfs2(son[u], ance);
  e := head[u];
  while e <> 0 do
  begin
    v := vet[e];
    if flag[v] = 0 then dfs2(v, v);   { a light child starts a new chain }
    e := next[e];
  end;
end;

{ Build segment-tree node p covering positions l..r. }
procedure build(l, r, p: longint);
var
  mid: longint;
begin
  if l = r then
  begin
    t1[p] := a[id[l]];
    t2[p] := a[id[l]];
    exit;
  end;
  mid := (l + r) shr 1;
  build(l, mid, p shl 1);
  build(mid + 1, r, (p shl 1) + 1);
  t1[p] := max(t1[p shl 1], t1[(p shl 1) + 1]);
  t2[p] := t2[p shl 1] + t2[(p shl 1) + 1];
end;

{ Set the value at position x to v. }
procedure update(l, r, x, v, p: longint);
var
  mid: longint;
begin
  if l = r then
  begin
    t1[p] := v;
    t2[p] := v;
    exit;
  end;
  mid := (l + r) shr 1;
  if x <= mid then update(l, mid, x, v, p shl 1)
  else update(mid + 1, r, x, v, (p shl 1) + 1);
  t1[p] := max(t1[p shl 1], t1[(p shl 1) + 1]);
  t2[p] := t2[p shl 1] + t2[(p shl 1) + 1];
end;

{ Maximum over positions x..y (requires l <= x <= y <= r). }
function querymax(l, r, x, y, p: longint): longint;
var
  mid: longint;
begin
  if (l = x) and (r = y) then exit(t1[p]);
  mid := (l + r) shr 1;
  if y <= mid then exit(querymax(l, mid, x, y, p shl 1));
  if x > mid then exit(querymax(mid + 1, r, x, y, (p shl 1) + 1));
  exit(max(querymax(l, mid, x, mid, p shl 1),
           querymax(mid + 1, r, mid + 1, y, (p shl 1) + 1)));
end;

{ Sum over positions x..y (requires l <= x <= y <= r). }
function querysum(l, r, x, y, p: longint): longint;
var
  mid: longint;
begin
  if (l = x) and (r = y) then exit(t2[p]);
  mid := (l + r) shr 1;
  if y <= mid then exit(querysum(l, mid, x, y, p shl 1));
  if x > mid then exit(querysum(mid + 1, r, x, y, (p shl 1) + 1));
  exit(querysum(l, mid, x, mid, p shl 1)
       + querysum(mid + 1, r, mid + 1, y, (p shl 1) + 1));
end;

procedure swap(var x, y: longint);
var
  t: longint;
begin
  t := x;
  x := y;
  y := t;
end;

{ Maximum weight on the tree path x..y (both ends included). }
function askmax(x, y: longint): longint;
var
  f1, f2, res: longint;
begin
  res := -maxlongint;
  f1 := top[x];
  f2 := top[y];
  while f1 <> f2 do
  begin
    { always climb from the endpoint whose chain top is deeper }
    if dep[f1] < dep[f2] then
    begin
      swap(f1, f2);
      swap(x, y);
    end;
    res := max(res, querymax(1, n, tid[f1], tid[x], 1));
    x := fa[f1];
    f1 := top[x];
  end;
  if dep[x] > dep[y] then swap(x, y);
  askmax := max(res, querymax(1, n, tid[x], tid[y], 1));
end;

{ Sum of weights on the tree path x..y (both ends included). }
function asksum(x, y: longint): longint;
var
  f1, f2, res: longint;
begin
  res := 0;
  f1 := top[x];
  f2 := top[y];
  while f1 <> f2 do
  begin
    if dep[f1] < dep[f2] then
    begin
      swap(f1, f2);
      swap(x, y);
    end;
    res := res + querysum(1, n, tid[f1], tid[x], 1);
    x := fa[f1];
    f1 := top[x];
  end;
  if dep[x] > dep[y] then swap(x, y);
  asksum := res + querysum(1, n, tid[x], tid[y], 1);
end;

{ Return the next token of s starting at index p (tokens are separated by
  blanks, tabs or a stray CR) and advance p past it; '' when exhausted. }
function nexttoken(const s: string; var p: longint): string;
var
  tok: string;
begin
  while (p <= length(s)) and (s[p] in [' ', #9, #13]) do inc(p);
  tok := '';
  while (p <= length(s)) and not (s[p] in [' ', #9, #13]) do
  begin
    tok := tok + s[p];
    inc(p);
  end;
  nexttoken := tok;
end;

begin
  //assign(input,'bzoj1036.in'); reset(input);
  //assign(output,'bzoj1036.out'); rewrite(output);
  readln(n);
  tot := 0;
  for i := 1 to n - 1 do
  begin
    readln(x, y);
    add(x, y);
    add(y, x);
  end;
  for i := 1 to n do read(a[i]);
  dfs1(1, -1, 1);
  fillchar(flag, sizeof(flag), 0);
  dfn := 0;
  dfs2(1, 1);
  build(1, n, 1);
  readln(q);
  for i := 1 to q do
  begin
    readln(line);
    p := 1;
    cmd := nexttoken(line, p);
    val(nexttoken(line, p), x, code);
    val(nexttoken(line, p), y, code);   { y may be negative for CHANGE }
    if copy(cmd, 1, 1) = 'C' then update(1, n, tid[x], y, 1)
    else if copy(cmd, 1, 2) = 'QM' then writeln(askmax(x, y))
    else if copy(cmd, 1, 2) = 'QS' then writeln(asksum(x, y));
  end;
  //close(input);
  //close(output);
end.
这是LCT写法,这种没有结构变化的树上操作还是写树剖吧,BZOJ上LCT只能卡着时限过
{ BZOJ1036 Count -- link-cut tree version.
  c[x,0]/c[x,1]: splay children; fa[x]: splay parent, or the path-parent
  pointer when x is the root of its splay tree; w: node weight;
  sum/mx: aggregates over x's splay subtree; rev: pending reversal flag.
  Node 0 is the null sentinel (sum 0, mx very small). }
const
  MAXN = 50000;

var
  c: array[0..MAXN, 0..1] of longint;
  mx, sum, w: array[0..MAXN] of int64;
  fa, rev, stk, a, b: array[0..MAXN] of longint;
  n, i, x, y, que, top, p: longint;
  code: integer;
  line, cmd: string;

{ Recompute sum/mx of x from its two children and its own weight. }
procedure update(x: longint);
var
  l, r: longint;
begin
  l := c[x, 0];
  r := c[x, 1];
  sum[x] := sum[l] + sum[r] + w[x];
  mx[x] := w[x];
  if mx[l] > mx[x] then mx[x] := mx[l];
  if mx[r] > mx[x] then mx[x] := mx[r];
end;

{ True when x is the root of its splay tree, i.e. fa[x] (if any) does not
  hold x as a splay child. }
function isroot(x: longint): boolean;
begin
  isroot := (c[fa[x], 0] <> x) and (c[fa[x], 1] <> x);
end;

{ Apply x's pending reversal: swap its children and pass the flag down. }
procedure pushdown(x: longint);
var
  tmp: longint;
begin
  if rev[x] = 1 then
  begin
    rev[x] := 0;
    rev[c[x, 0]] := rev[c[x, 0]] xor 1;
    rev[c[x, 1]] := rev[c[x, 1]] xor 1;
    tmp := c[x, 0];
    c[x, 0] := c[x, 1];
    c[x, 1] := tmp;
  end;
end;

{ Rotate x above its splay parent y. }
procedure rotate(x: longint);
var
  y, z, l, r: longint;
begin
  y := fa[x];
  z := fa[y];
  if c[y, 1] = x then l := 1 else l := 0;
  r := 1 - l;
  { must be tested before fa[y] changes; if y is a splay root, z only
    holds a path-parent pointer and must not get a child link }
  if not isroot(y) then
  begin
    if c[z, 1] = y then c[z, 1] := x else c[z, 0] := x;
  end;
  fa[c[x, r]] := y;
  fa[y] := x;
  fa[x] := z;
  c[y, l] := c[x, r];
  c[x, r] := y;
  update(y);
  update(x);
end;

{ Splay x to the root of its splay tree, first pushing down pending
  reversals from that root to x. }
procedure splay(x: longint);
var
  k, y, z: longint;
begin
  top := 1;
  stk[top] := x;
  k := x;
  while not isroot(k) do
  begin
    k := fa[k];
    inc(top);
    stk[top] := k;
  end;
  while top > 0 do
  begin
    pushdown(stk[top]);
    dec(top);
  end;
  while not isroot(x) do
  begin
    y := fa[x];
    z := fa[y];
    if not isroot(y) then
    begin
      { zig-zag rotates x twice, zig-zig rotates y first }
      if (c[y, 0] = x) xor (c[z, 0] = y) then rotate(x) else rotate(y);
    end;
    rotate(x);
  end;
end;

{ Make the path from the represented root to x the preferred path. }
procedure access(x: longint);
var
  last: longint;
begin
  last := 0;
  while x > 0 do
  begin
    splay(x);
    c[x, 1] := last;
    update(x);
    last := x;
    x := fa[x];
  end;
end;

{ Re-root the represented tree at x. }
procedure makeroot(x: longint);
begin
  access(x);
  splay(x);
  rev[x] := rev[x] xor 1;
end;

procedure link(x, y: longint);
begin
  makeroot(x);
  fa[x] := y;
end;

{ Expose path x..y; afterwards y is the splay root and sum[y]/mx[y] are
  the aggregates of exactly that path. }
procedure split(x, y: longint);
begin
  makeroot(x);
  access(y);
  splay(y);
end;

{ Return the next token of s starting at index p (tokens are separated by
  blanks, tabs or a stray CR) and advance p past it; '' when exhausted. }
function nexttoken(const s: string; var p: longint): string;
var
  tok: string;
begin
  while (p <= length(s)) and (s[p] in [' ', #9, #13]) do inc(p);
  tok := '';
  while (p <= length(s)) and not (s[p] in [' ', #9, #13]) do
  begin
    tok := tok + s[p];
    inc(p);
  end;
  nexttoken := tok;
end;

begin
  { file redirection is for local testing only; the judge uses stdin/stdout }
  //assign(input,'bzoj1036.in'); reset(input);
  //assign(output,'bzoj1036.out'); rewrite(output);
  readln(n);
  mx[0] := -maxlongint div 2;
  for i := 1 to n - 1 do readln(a[i], b[i]);
  for i := 1 to n do
  begin
    read(w[i]);
    sum[i] := w[i];
    mx[i] := w[i];
  end;
  for i := 1 to n - 1 do link(a[i], b[i]);
  readln(que);
  for i := 1 to que do
  begin
    readln(line);
    p := 1;
    cmd := nexttoken(line, p);
    val(nexttoken(line, p), x, code);
    val(nexttoken(line, p), y, code);   { y may be negative for CHANGE }
    if copy(cmd, 1, 1) = 'C' then
    begin
      splay(x);
      w[x] := y;
      update(x);
    end
    else if copy(cmd, 1, 2) = 'QM' then
    begin
      split(x, y);
      writeln(mx[y]);
    end
    else if copy(cmd, 1, 2) = 'QS' then
    begin
      split(x, y);
      writeln(sum[y]);
    end;
  end;
  //close(input);
  //close(output);
end.
UPD(2018.9.19):C++ 树链剖分写法
// BZOJ1036 Count -- heavy-light decomposition + segment tree
// (range maximum and range sum over DFS order).
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;

#define N 210000
const int INF = 1000000000;   // below any weight (|w| <= 30000)

// Adjacency list.
int head[N], vet[N], nxt[N], tot;
// HLD data, indexed by vertex (tid/id map vertex <-> DFS position).
// The subtree-size array is `siz`, not `size`: with `using namespace std;`
// a global named `size` is ambiguous with std::size (C++17) and won't compile.
int fa[N], dep[N], siz[N], son[N], top[N], tid[N], id[N], flag[N], cnt;
int a[N], n;
// Segment tree over DFS positions.
int mx[N];
ll sum[N];

void add(int u, int v)
{
    nxt[++tot] = head[u];
    vet[tot] = v;
    head[u] = tot;
}

// First pass: parent, depth, subtree size and heavy child of every vertex.
void dfs1(int u)
{
    flag[u] = 1;
    siz[u] = 1;
    son[u] = 0;
    int best = 0;
    for (int e = head[u]; e; e = nxt[e])
    {
        int v = vet[e];
        if (flag[v]) continue;
        fa[v] = u;
        dep[v] = dep[u] + 1;
        dfs1(v);
        siz[u] += siz[v];
        if (siz[v] > best)
        {
            best = siz[v];
            son[u] = v;
        }
    }
}

// Second pass: DFS positions with the heavy child first, so every heavy
// chain (top vertex `ance`) is a contiguous range of positions.
void dfs2(int u, int ance)
{
    flag[u] = 1;
    tid[u] = ++cnt;
    id[cnt] = u;
    top[u] = ance;
    if (son[u]) dfs2(son[u], ance);
    for (int e = head[u]; e; e = nxt[e])
    {
        int v = vet[e];
        if (!flag[v]) dfs2(v, v);   // light child starts a new chain
    }
}

void pushup(int p)
{
    mx[p] = max(mx[p << 1], mx[p << 1 | 1]);
    sum[p] = sum[p << 1] + sum[p << 1 | 1];
}

void build(int l, int r, int p)
{
    if (l == r)
    {
        mx[p] = a[id[l]];
        sum[p] = a[id[l]];
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, p << 1);
    build(mid + 1, r, p << 1 | 1);
    pushup(p);
}

// Set the value at position x to v.
void update(int l, int r, int x, int v, int p)
{
    if (l == r)
    {
        mx[p] = v;
        sum[p] = v;
        return;
    }
    int mid = (l + r) >> 1;
    if (x <= mid) update(l, mid, x, v, p << 1);
    else update(mid + 1, r, x, v, p << 1 | 1);
    pushup(p);
}

int querymax(int l, int r, int x, int y, int p)
{
    if (x <= l && r <= y) return mx[p];
    int mid = (l + r) >> 1;
    int ans = -INF;
    if (x <= mid) ans = max(ans, querymax(l, mid, x, y, p << 1));
    if (y > mid) ans = max(ans, querymax(mid + 1, r, x, y, p << 1 | 1));
    return ans;
}

ll querysum(int l, int r, int x, int y, int p)
{
    if (x <= l && r <= y) return sum[p];
    int mid = (l + r) >> 1;
    ll ans = 0;
    if (x <= mid) ans += querysum(l, mid, x, y, p << 1);
    if (y > mid) ans += querysum(mid + 1, r, x, y, p << 1 | 1);
    return ans;
}

// Maximum weight on the path x..y (both ends included).
int qmax(int x, int y)
{
    int ans = -INF;
    int f1 = top[x], f2 = top[y];
    while (f1 != f2)
    {
        // climb from the endpoint whose chain top is deeper
        if (dep[f1] < dep[f2])
        {
            swap(f1, f2);
            swap(x, y);
        }
        ans = max(ans, querymax(1, n, tid[f1], tid[x], 1));
        x = fa[f1];
        f1 = top[x];
    }
    if (dep[x] > dep[y]) swap(x, y);
    return max(ans, querymax(1, n, tid[x], tid[y], 1));
}

// Sum of weights on the path x..y (both ends included).
ll qsum(int x, int y)
{
    ll ans = 0;
    int f1 = top[x], f2 = top[y];
    while (f1 != f2)
    {
        if (dep[f1] < dep[f2])
        {
            swap(f1, f2);
            swap(x, y);
        }
        ans += querysum(1, n, tid[f1], tid[x], 1);
        x = fa[f1];
        f1 = top[x];
    }
    if (dep[x] > dep[y]) swap(x, y);
    return ans + querysum(1, n, tid[x], tid[y], 1);
}

int main()
{
    // Local testing only; the judge reads stdin and writes stdout.
    //freopen("bzoj1036.in","r",stdin);
    //freopen("bzoj1036.out","w",stdout);
    if (scanf("%d", &n) != 1) return 0;
    for (int i = 1; i <= n - 1; i++)
    {
        int x, y;
        scanf("%d%d", &x, &y);
        add(x, y);
        add(y, x);
    }
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);

    cnt = 0;
    dfs1(1);
    memset(flag, 0, sizeof(flag));
    dfs2(1, 1);
    build(1, n, 1);

    int q = 0;
    scanf("%d", &q);
    char cmd[16];
    for (int i = 1; i <= q; i++)
    {
        int x, y;
        if (scanf("%15s%d%d", cmd, &x, &y) != 3) break;
        if (cmd[0] == 'C') update(1, n, tid[x], y, 1);
        else if (cmd[1] == 'M') printf("%d\n", qmax(x, y));
        else if (cmd[1] == 'S') printf("%lld\n", qsum(x, y));
    }
    return 0;
}
UPD(2018.9.20):C++ LCT写法
// BZOJ1036 Count -- link-cut tree version.
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;

#define N 210000
const ll INF = 1000000000;

// c[x][0/1]: splay children; fa[x]: splay parent, or the path-parent pointer
// when x is the root of its splay tree; rev[x]: pending reversal flag.
// Node 0 is the null sentinel (sum 0, mx -INF so it never wins a max).
int c[N][2], fa[N], rev[N];
int stk[N], top;            // scratch stack used by splay()
int ea[N], eb[N];           // input edges, linked after the weights are read
ll w[N], sum[N], mx[N];

// True when x is the root of its splay tree.
bool isroot(int x)
{
    return c[fa[x]][0] != x && c[fa[x]][1] != x;
}

// Recompute sum/mx of x from its children and its own weight.
void pushup(int x)
{
    int l = c[x][0], r = c[x][1];
    sum[x] = sum[l] + sum[r] + w[x];
    mx[x] = max(w[x], max(mx[l], mx[r]));
}

// Apply x's pending reversal: swap its children and pass the flag down.
void pushdown(int x)
{
    if (!rev[x]) return;
    rev[x] = 0;
    rev[c[x][0]] ^= 1;
    rev[c[x][1]] ^= 1;
    swap(c[x][0], c[x][1]);
}

// Rotate x above its splay parent.
void rotate(int x)
{
    int y = fa[x], z = fa[y];
    int l = (c[y][1] == x), r = l ^ 1;
    // tested before fa[y] changes; a splay-root y only has a path-parent z
    if (!isroot(y)) c[z][c[z][1] == y] = x;
    fa[c[x][r]] = y;
    fa[y] = x;
    fa[x] = z;
    c[y][l] = c[x][r];
    c[x][r] = y;
    pushup(y);
    pushup(x);
}

// Splay x to the root of its splay tree, pushing reversals down first.
void splay(int x)
{
    top = 0;
    stk[++top] = x;
    for (int i = x; !isroot(i); i = fa[i]) stk[++top] = fa[i];
    while (top) pushdown(stk[top--]);
    while (!isroot(x))
    {
        int y = fa[x], z = fa[y];
        if (!isroot(y))
            rotate(((c[y][0] == x) != (c[z][0] == y)) ? x : y);  // zig-zag : zig-zig
        rotate(x);
    }
}

// Make the path from the represented root to x the preferred path.
void access(int x)
{
    for (int last = 0; x; last = x, x = fa[x])
    {
        splay(x);
        c[x][1] = last;
        pushup(x);
    }
}

// Re-root the represented tree at x.
void makeroot(int x)
{
    access(x);
    splay(x);
    rev[x] ^= 1;
}

void link(int x, int y)
{
    makeroot(x);
    fa[x] = y;
}

// Expose path x..y; afterwards sum[y]/mx[y] describe exactly that path.
void split(int x, int y)
{
    makeroot(x);
    access(y);
    splay(y);
}

int main()
{
    // Local testing only; the judge reads stdin and writes stdout.
    //freopen("bzoj1036.in","r",stdin);
    //freopen("bzoj1036.out","w",stdout);
    int n;
    if (scanf("%d", &n) != 1) return 0;
    mx[0] = -INF;
    for (int i = 1; i <= n - 1; i++) scanf("%d%d", &ea[i], &eb[i]);
    for (int i = 1; i <= n; i++)
    {
        scanf("%lld", &w[i]);
        mx[i] = sum[i] = w[i];
    }
    for (int i = 1; i <= n - 1; i++) link(ea[i], eb[i]);

    int queries = 0;
    scanf("%d", &queries);
    char cmd[16];
    for (int i = 1; i <= queries; i++)
    {
        int x, y;
        if (scanf("%15s%d%d", cmd, &x, &y) != 3) break;
        if (cmd[0] == 'C')
        {
            splay(x);   // x becomes its splay root, so only x's aggregate changes
            w[x] = y;
            pushup(x);
        }
        else if (cmd[1] == 'M')
        {
            split(x, y);
            printf("%lld\n", mx[y]);
        }
        else if (cmd[1] == 'S')
        {
            split(x, y);
            printf("%lld\n", sum[y]);
        }
    }
    return 0;
}
null