BZOJ 2243: [SDOI2011] 染色（树链剖分）
直接树链剖分就可以啦。
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
using namespace std;

// Capacity of every vertex-indexed array; vertices are numbered from 1.
const int MAXN = 110000;
typedef int ll;  // not referenced by the code visible here; kept for compatibility

// Forward-star adjacency list: head[u] is the first edge id of u (0 = none),
// nxt[e] the next edge id in the same list, to[e] the endpoint of edge e.
// q is the BFS order. dep/fa/siz/son are depth, parent, subtree size and
// heavy child. top[v] is the head of v's chain, pos[v] the position of v in
// the segment tree, idx[] the inverse of pos[]. jl is not referenced by the
// code visible here.
int head[MAXN], to[MAXN * 2], nxt[MAXN * 2], q[MAXN], dep[MAXN], son[MAXN],
    siz[MAXN], top[MAXN], pos[MAXN], idx[MAXN], fa[MAXN], jl[MAXN];

// Segment tree over positions 1..sum. For node k:
//   segs[k]  number of maximal same-color runs in the node's range
//   lcol[k]  color of the leftmost position, rcol[k] of the rightmost one
//   lzy[k]   pending "paint everything" color, or -1 when nothing is pending
// (segs was called `data`; with `using namespace std;` that name is ambiguous
// with std::data when compiled as C++17.)
// p[v][j] is the 2^j-th ancestor of v, or 0 when there is none.
int segs[MAXN * 5], lcol[MAXN * 5], rcol[MAXN * 5], lzy[MAXN * 5], p[MAXN][32];

int n, m, a, b, c, sum, cnt;
int col[MAXN];  // initial color of every vertex
char s[10];     // buffer for the operation token read by main

// Appends the directed edge a -> b to a's adjacency list.
void add(int a, int b)
{
    nxt[++cnt] = head[a];
    to[cnt] = b;
    head[a] = cnt;
}

// Recomputes node k from its two children; two runs that meet at the
// boundary with the same color are counted once.
void updata(int k)
{
    const int lc = k << 1;
    const int rc = k << 1 | 1;
    segs[k] = segs[lc] + segs[rc];
    if (rcol[lc] == lcol[rc])
        segs[k]--;
    lcol[k] = lcol[lc];
    rcol[k] = rcol[rc];
}

// Builds the subtree of node k covering positions [l, r] from col[idx[.]].
void build(int k, int l, int r)
{
    lzy[k] = -1;
    if (l == r)
    {
        segs[k] = 1;
        lcol[k] = rcol[k] = col[idx[l]];
        return;
    }
    const int mid = (l + r) >> 1;
    build(k << 1, l, mid);
    build(k << 1 | 1, mid + 1, r);
    updata(k);
}

// Fills the binary-lifting table for every level j with 2^j <= n.
// Level 0 must already hold the parents (init() writes it).
void lca_init()
{
    for (int j = 1; (1 << j) <= n; j++)
        for (int i = 1; i <= n; i++)
            p[i][j] = p[p[i][j - 1]][j - 1];
}

// Roots the tree at vertex 1 and prepares everything the queries need:
// BFS order, depths, parents, subtree sizes, heavy children, chain heads,
// segment-tree positions, the segment tree itself and the LCA table.
void init()
{
    int que = 0, u, v;
    q[++que] = 1;
    dep[1] = 1;
    // Walk the queue from its first slot. The original started at
    // i = q[que], which only equals 1 because the root is vertex 1.
    for (int i = 1; i <= que; i++)
    {
        siz[u = q[i]] = 1;
        for (int o = head[u]; o; o = nxt[o])
        {
            v = to[o];
            if (v == fa[u])
                continue;
            fa[v] = u;
            p[v][0] = u;
            dep[v] = dep[u] + 1;
            q[++que] = v;
        }
    }
    // Reverse BFS order visits children before parents. For the root,
    // fa is 0, so slot 0 of siz[] absorbs the final update.
    for (int i = que; i >= 1; i--)
    {
        u = q[i];
        siz[v = fa[u]] += siz[u];
        if (siz[u] > siz[son[v]])
            son[v] = u;
    }
    // A vertex without a chain head starts a new chain that follows heavy
    // children, so each chain occupies consecutive positions.
    for (int i = 1; i <= que; i++)
    {
        if (top[u = q[i]])
            continue;
        for (v = u; v; v = son[v])
        {
            top[v] = u;
            idx[pos[v] = ++sum] = v;
        }
    }
    build(1, 1, sum);
    lca_init();
}

// Returns the lowest common ancestor of a and b by binary lifting.
int lca(int a, int b)
{
    if (dep[a] < dep[b]) swap(a, b);
    // Highest level filled by lca_init(): the largest j with 2^j <= n.
    // Computed with integers instead of the floating-point log2(n).
    int hi = 0;
    while (hi + 1 < 31 && (1 << (hi + 1)) <= n)
        ++hi;
    for (int j = hi; j >= 0; j--)
        if (dep[a] - (1 << j) >= dep[b])
            a = p[a][j];
    if (a == b) return b;
    for (int j = hi; j >= 0; j--)
        if (p[a][j] != p[b][j])
        {
            a = p[a][j];
            b = p[b][j];
        }
    return p[a][0];
}

// Moves the pending paint of node k to both children. Callers only invoke
// this when lzy[k] >= 0.
void push_down(int k)
{
    const int lc = k << 1;
    const int rc = k << 1 | 1;
    const int paint = lzy[k];
    segs[lc] = segs[rc] = 1;
    lzy[lc] = lzy[rc] = paint;
    lcol[lc] = rcol[lc] = lcol[rc] = rcol[rc] = paint;
    lzy[k] = -1;
}

// Paints positions [x, y] with color xg; node k covers [l, r].
// NOTE(review): -1 marks "no pending paint", so colors are assumed to be
// non-negative -- confirm against the problem limits.
void vec_chg(int k, int l, int r, int x, int y, int xg)
{
    if (x <= l && r <= y)
    {
        lzy[k] = xg;
        lcol[k] = rcol[k] = xg;
        segs[k] = 1;
        return;
    }
    if (lzy[k] >= 0) push_down(k);
    const int mid = (l + r) >> 1;
    if (x <= mid) vec_chg(k << 1, l, mid, x, y, xg);
    if (y > mid) vec_chg(k << 1 | 1, mid + 1, r, x, y, xg);
    updata(k);
}

// Returns the number of color runs inside positions [x, y].
int vec_qry(int k, int l, int r, int x, int y)
{
    if (x <= l && r <= y) return segs[k];
    if (lzy[k] >= 0) push_down(k);
    const int mid = (l + r) >> 1;
    int res = 0;
    if (x <= mid) res += vec_qry(k << 1, l, mid, x, y);
    if (y > mid) res += vec_qry(k << 1 | 1, mid + 1, r, x, y);
    // When the query spans both halves, positions mid and mid + 1 are both
    // inside it; equal colors there mean one run was counted twice.
    if (x <= mid && y > mid && rcol[k << 1] == lcol[k << 1 | 1]) res--;
    return res;
}

// Returns the color at position x, the left end of [x, y]. Both visible
// callers pass x == y.
int vec_col(int k, int l, int r, int x, int y)
{
    if (x <= l && r <= y) return lcol[k];
    if (lzy[k] >= 0) push_down(k);
    const int mid = (l + r) >> 1;
    if (x <= mid) return vec_col(k << 1, l, mid, x, y);
    // x > mid together with x <= y puts the whole range in the right child.
    // Returning unconditionally removes the original's path that could run
    // off the end of the function without a value.
    return vec_col(k << 1 | 1, mid + 1, r, x, y);
}

// Paints every vertex on the tree path a..b with color xg.
void query_chg(int a, int b, int xg)
{
    while (top[a] != top[b])
    {
        // Always lift the endpoint whose chain head is deeper.
        if (dep[top[a]] < dep[top[b]]) swap(a, b);
        vec_chg(1, 1, sum, pos[top[a]], pos[a], xg);
        a = fa[top[a]];
    }
    if (dep[a] > dep[b]) swap(a, b);
    vec_chg(1, 1, sum, pos[a], pos[b], xg);
}

// Returns the number of color runs on the tree path a..b. main only calls
// it with b equal to the LCA of the original endpoints.
int query_sum(int a, int b)
{
    int res = 0;
    while (top[a] != top[b])
    {
        if (dep[top[a]] < dep[top[b]]) swap(a, b);
        const int chain_head = top[a];
        const int above = fa[chain_head];
        res += vec_qry(1, 1, sum, pos[chain_head], pos[a]);
        // A run that continues across the light edge would otherwise be
        // counted once in each chain.
        if (vec_col(1, 1, sum, pos[chain_head], pos[chain_head]) ==
            vec_col(1, 1, sum, pos[above], pos[above]))
            res--;
        a = above;
    }
    if (dep[a] > dep[b]) swap(a, b);
    res += vec_qry(1, 1, sum, pos[a], pos[b]);
    return res;
}
167 int main() 168 { 169 scanf("%d%d",&n,&m); 170 for (int i = 1;i <= n;i++) 171 scanf("%d",&col[i]); 172 for (int i = 1;i <= n - 1;i++) 173 { 174 scanf("%d%d",&a,&b); 175 add(a,b); 176 add(b,a); 177 } 178 init(); 179 for (int i = 1;i <= m;i++) 180 { 181 scanf("%s",s); 182 if (s[0] == 'C') 183 { 184 scanf("%d%d%d",&a,&b,&c); 185 query_chg(a,b,c); 186 }else 187 { 188 scanf("%d%d",&a,&b); 189 int tt = lca(a,b); 190 printf("%d\n",query_sum(a,tt) + query_sum(b,tt) - 1); 191 } 192 } 193 return 0; 194 }
心之所动 且就随缘去吧