BZOJ2243 (树链剖分+线段树)
Problem 染色(BZOJ2243)
题目大意
给定一颗树,每个节点上有一种颜色。
要求支持两种操作:
操作1:将a->b上所有点染成一种颜色。
操作2:询问a->b上的颜色段数量。
解题分析
树链剖分+线段树。
开一个记录类型,记录某一段区间的信息。l 表示区间最左侧的颜色 , r 表示区间最右侧的颜色 , sum 表示区间中颜色段数量。
合并时判断一下左区间的右端点和有区间的左端点的颜色是否一样。
树上合并时需要用两个变量ans1,ans2来存储。ans1表示x往上走时形成的链的信息,ans2表示y往上走时形成链的信息。
当x和y位于同一条重链上时,有三个区间需要合并在一起,注意判断顺序。
参考程序
1 #include <cstdio> 2 #include <cstring> 3 #include <cmath> 4 #include <algorithm> 5 using namespace std; 6 7 #define V 100008 8 #define E 200008 9 #define lson l,m,rt<<1 10 #define rson m+1,r,rt<<1|1 11 12 int n,m,cnt; 13 int a[V],size[V],dep[V],fa[V],son[V],top[V],w[V],rk[V]; 14 15 struct line{ 16 int u,v,nt; 17 }eg[E]; 18 int sum,lt[V]; 19 20 struct color{ 21 int l,r,sum; 22 color(int a=-1,int b=-1,int c=-1):l(a),r(b),sum(c){} 23 }; 24 color merge(color a,color b){ 25 color c; 26 if (a.sum==-1) return b; 27 if (b.sum==-1) return a; 28 if (a.r==b.l){ 29 c.sum=a.sum+b.sum-1; 30 c.l=a.l; 31 c.r=b.r; 32 } 33 else 34 { 35 c.sum=a.sum+b.sum; 36 c.l=a.l; 37 c.r=b.r; 38 } 39 return c; 40 } 41 42 struct segment_tree{ 43 color tag[V<<2]; 44 int lazy[V<<2]; 45 void pushup(int rt){ 46 tag[rt]=merge(tag[rt<<1],tag[rt<<1|1]); 47 } 48 void pushdown(int rt){ 49 if (lazy[rt]){ 50 lazy[rt<<1]=lazy[rt<<1|1]=lazy[rt]; 51 tag[rt<<1].l=tag[rt<<1].r=lazy[rt]; 52 tag[rt<<1|1].l=tag[rt<<1|1].r=lazy[rt]; 53 tag[rt<<1].sum=tag[rt<<1|1].sum=1; 54 lazy[rt]=0; 55 return; 56 } 57 } 58 void build(int l,int r,int rt){ 59 if (l==r){ 60 tag[rt].l=tag[rt].r=a[rk[l]]; 61 tag[rt].sum=1; 62 return; 63 } 64 int m=(l+r)/2; 65 build(lson); 66 build(rson); 67 pushup(rt); 68 } 69 void update(int L,int R,int val,int l,int r,int rt){ 70 if (L<=l && r<=R){ 71 tag[rt].l=tag[rt].r=val; 72 tag[rt].sum=1; 73 lazy[rt]=val; 74 return; 75 } 76 pushdown(rt); 77 int m=(l+r)/2; 78 if (L <= m) update(L,R,val,lson); 79 if (m < R) update(L,R,val,rson); 80 pushup(rt); 81 } 82 color query(int L,int R,int l,int r,int rt){ 83 if (L<=l && r<=R){ 84 return tag[rt]; 85 } 86 pushdown(rt); 87 color res; 88 int m=(l+r)/2; 89 if (L <= m) res=merge(res,query(L,R,lson)); 90 if (m < R) res=merge(res,query(L,R,rson)); 91 return res; 92 } 93 }T; 94 95 void adt(int u,int v){ 96 eg[++sum].u=u; eg[sum].v=v; eg[sum].nt=lt[u]; lt[u]=sum; 97 } 98 void add(int u,int v){ 99 adt(u,v); adt(v,u); 100 } 101 102 void dfs_1(int u){ 103 size[u]=1; dep[u]=dep[fa[u]]+1; son[u]=0; 104 for (int i=lt[u];i;i=eg[i].nt){ 105 int v=eg[i].v; 106 if (v==fa[u]) continue; 107 fa[v]=u; 108 dfs_1(v); 109 size[u]+=size[v]; 110 if (size[v]>size[son[u]]) son[u]=v; 111 } 112 } 113 void dfs_2(int u,int tp){ 114 w[u]=++cnt; top[u]=tp; rk[cnt]=u; 115 if (son[u]) dfs_2(son[u],tp); 116 for (int i=lt[u];i;i=eg[i].nt){ 117 int v=eg[i].v; 118 if (v==fa[u] || v==son[u]) continue; 119 dfs_2(v,v); 120 } 121 } 122 void change(int x,int y,int val){ 123 while (top[x]!=top[y]){ 124 if (dep[top[x]]<dep[top[y]]) swap(x,y); 125 T.update(w[top[x]],w[x],val,1,n,1); 126 x=fa[top[x]]; 127 } 128 if (dep[x]>dep[y]) swap(x,y); 129 T.update(w[x],w[y],val,1,n,1); 130 } 131 void pt(color a){ 132 printf("%d %d %d\n",a.l,a.r,a.sum); 133 } 134 void find(int x,int y){ 135 color ans1,ans2,ans; 136 while (top[x]!=top[y]){ 137 if (dep[top[x]]>dep[top[y]]){ 138 ans1=merge(T.query(w[top[x]],w[x],1,n,1),ans1); 139 x=fa[top[x]]; 140 } 141 else 142 { 143 ans2=merge(T.query(w[top[y]],w[y],1,n,1),ans2); 144 y=fa[top[y]]; 145 } 146 147 } 148 if (dep[x]<dep[y]){ 149 ans=T.query(w[x],w[y],1,n,1); 150 ans=merge(ans,ans2); 151 swap(ans.l,ans.r); 152 ans=merge(ans,ans1); 153 } 154 else 155 { 156 ans=T.query(w[y],w[x],1,n,1); 157 ans=merge(ans,ans1); 158 swap(ans.l,ans.r); 159 ans=merge(ans,ans2); 160 } 161 printf("%d\n",ans.sum ); 162 } 163 164 int main(){ 165 memset(lt,0,sizeof(lt)); sum=1; cnt=0; 166 scanf("%d %d",&n,&m); 167 for (int i=1;i<=n;i++) scanf("%d",&a[i]); 168 for (int i=1;i<n;i++){ 169 int u,v; 170 scanf("%d %d",&u,&v); 171 add(u,v); 172 } 173 dfs_1(1); 174 dfs_2(1,1); 175 T.build(1,n,1); 176 while (m--){ 177 char ch[2]; 178 int x,y,z; 179 scanf("%s",ch); 180 if (ch[0]=='Q'){ 181 scanf("%d %d",&x,&y); 182 find(x,y); 183 } 184 else 185 { 186 scanf("%d %d %d",&x,&y,&z); 187 change(x,y,z); 188 }; 189 } 190 }