【洛谷P3384】树链剖分
这是一道树链剖分的模板题,首先要学会线段树、dfs、链式前向星之类的,不然……//打暴力吧
本题很考验代码能力,难度不大,主要是细节繁多。(我调了1h)
树链剖分的原理不再赘述(详见《信息学奥赛一本通·提高版》),主要说一下一些容易错的细节。
ps:本人树链剖分的写法来源于《信息学奥赛一本通·提高版》,如有漏洞,欢迎指责。
1 #include <iostream> 2 #include <cstdio> 3 #include <cstring> 4 typedef long long ll; 5 inline int read() { 6 int ret=0,f=1; 7 char c=getchar(); 8 while(c<'0'||c>'9') {if(c=='-') f=-1;c=getchar();} 9 while(c<='9'&&c>='0') ret=ret*10+c-'0',c=getchar(); 10 return ret*f; 11 } 12 using namespace std; 13 const int N=100010; 14 int n,m,root,mod; 15 int val[N],fa[N],top[N],size[N],son[N],d[N]; 16 int seg[N],rev[N],sum[N<<2],tag[N<<2]; 17 int ans; 18 struct edge { 19 int next,to; 20 }a[N<<1]; 21 int num,head[N<<1]; 22 inline void add(int from,int to) { 23 a[++num].next=head[from]; a[num].to=to; head[from]=num; 24 swap(from,to); 25 a[++num].next=head[from]; a[num].to=to; head[from]=num; 26 } 27 void dfs1(int u,int f) { 28 d[u]=d[f]+1; 29 size[u]=1; 30 fa[u]=f; 31 for(int i=head[u];i;i=a[i].next) { 32 int v=a[i].to; 33 if(v==f) continue ; 34 dfs1(v,u); 35 size[u]+=size[v]; 36 if(size[v]>size[son[u]]) son[u]=v; 37 } 38 } 39 void dfs2(int u,int f) { 40 if(son[u]) { 41 top[son[u]]=top[u]; 42 seg[son[u]]=++seg[0]; 43 rev[seg[0]]=son[u]; 44 dfs2(son[u],u); 45 } 46 for(int i=head[u];i;i=a[i].next) { 47 int v=a[i].to; 48 if(!top[v]) { 49 top[v]=v; 50 seg[v]=++seg[0]; 51 rev[seg[0]]=v; 52 dfs2(v,u); 53 } 54 } 55 } 56 inline void pushdown(int now,int len) { 57 if(tag[now]) { 58 tag[now<<1]+=tag[now]; 59 tag[now<<1|1]+=tag[now]; 60 sum[now<<1]+=tag[now]*(len-(len>>1)); 61 sum[now<<1|1]+=tag[now]*(len>>1); 62 tag[now]=0; 63 sum[now<<1]%=mod; 64 sum[now<<1|1]%=mod; 65 } 66 } 67 inline void build(int now,int l,int r) { 68 if(l==r) { 69 sum[now]=val[rev[l]]; 70 sum[now]%=mod; 71 return ; 72 } 73 int mid=l+r>>1; 74 build(now<<1,l,mid); 75 build(now<<1|1,mid+1,r); 76 sum[now]=sum[now<<1]+sum[now<<1|1]; 77 sum[now]%=mod; 78 } 79 void updata(int now,int l,int r,int x,int y,int z) { 80 if(x<=l&&r<=y) { 81 tag[now]+=z; 82 sum[now]+=z*(r-l+1); 83 sum[now]%=mod; 84 return ; 85 } 86 pushdown(now,r-l+1); 87 int mid=l+r>>1; 88 if(x<=mid) updata(now<<1,l,mid,x,y,z); 89 if(y>mid) updata(now<<1|1,mid+1,r,x,y,z); 90 sum[now]=sum[now<<1]+sum[now<<1|1]; 91 sum[now]%=mod; 92 } 93 void query(int now,int l,int r,int x,int y) { 94 if(x<=l&&r<=y) { 95 ans+=sum[now]; 96 ans%=mod; 97 return ; 98 } 99 pushdown(now,r-l+1); 100 int mid=l+r>>1; 101 if(x<=mid) query(now<<1,l,mid,x,y); 102 if(y>mid) query(now<<1|1,mid+1,r,x,y); 103 } 104 void find(int x,int y,int z,int op) { 105 while(top[x]!=top[y]) { 106 if(d[top[x]]<d[top[y]]) swap(x,y); 107 if(op==1) updata(1,1,seg[0],seg[top[x]],seg[x],z); 108 else query(1,1,seg[0],seg[top[x]],seg[x]); 109 ans%=mod; 110 x=fa[top[x]]; 111 } 112 if(d[x]>d[y]) swap(x,y); 113 if(op==1) updata(1,1,seg[0],seg[x],seg[y],z); 114 else query(1,1,seg[0],seg[x],seg[y]); 115 ans%=mod; 116 } 117 void find(int x,int y,int op) { 118 if(op==1) updata(1,1,seg[0],seg[x],seg[x]+size[x]-1,y); 119 else query(1,1,seg[0],seg[x],seg[x]+size[x]-1); 120 ans%=mod; 121 } 122 int main() { 123 n=read(); m=read(); root=read(); mod=read(); 124 for(int i=1;i<=n;i++) val[i]=read(); 125 for(int i=1;i<n;i++) add(read(),read()); 126 dfs1(root,0); 127 seg[0]=seg[root]=1; 128 rev[1]=root; 129 top[root]=root; 130 dfs2(root,root); 131 build(1,1,seg[0]); 132 while(m--) { 133 int op=read(); 134 int x,y,z; 135 if(op==1) { 136 x=read(); y=read(); z=read(); 137 find(x,y,z,1); 138 } 139 else if(op==2) { 140 x=read(); y=read(); z=-1; 141 ans=0; 142 find(x,y,z,2); 143 printf("%d\n",ans%mod); 144 ans=0; 145 } 146 else if(op==3) { 147 x=read(); y=read(); 148 find(x,y,1); 149 } 150 else { 151 x=read(); y=-1; 152 ans=0; 153 find(x,y,2); 154 printf("%d\n",ans%mod); 155 ans=0; 156 } 157 } 158 return 0; 159 }
- 代码较长,尽量让代码简洁,明亮,避免调试时出现恶心现象(生理上的不适)
- 注意对答案的取模
- 线段树的细节注意(标记下传)
- 注意区分seg和rev的含义,避免混淆
- dfs2函数的写法(第48行)
- 千万不能让读入优化、建图、dfs这些基本的函数出错,在写的时候慢一点,避免手残。否则欲哭无泪(难以置信自己居然错在这里)……555