SPOJ COT2 Count on a tree II(树上莫队)
题目链接:http://www.spoj.com/problems/COT2/
You are given a tree with N nodes. The tree nodes are numbered from 1 to N. Each node has an integer weight.
We will ask you to perform the following operation:
- u v: ask how many different integers occur as node weights on the path from u to v.
Input
In the first line there are two integers N and M (N <= 40000, M <= 100000).
In the second line there are N integers. The i-th integer denotes the weight of the i-th node.
In the next N-1 lines, each line contains two integers u v, which describe an edge (u, v).
In the next M lines, each line contains two integers u v, describing an operation that asks how many different integers occur as node weights on the path from u to v.
Output
For each operation, print its result.
题目大意:给一棵树,每个点有一个权值。多次询问路径(a, b)上有多少个权值不同的点。
思路:参考VFK WC 2013 糖果公园 park 题解(此题比COT2要难。)
http://vfleaking.blog.163.com/blog/static/174807634201311011201627/
代码(2.37S):
// SPOJ COT2 "Count on a Tree II": for each query (u, v), count the distinct
// node weights on the tree path u-v. Queries are answered offline with Mo's
// algorithm on a tree (tree blocking + path symmetric-difference moves).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
using namespace std;

const int MAXV = 40010;      // max nodes (N <= 40000) plus slack
const int MAXE = MAXV << 1;  // every undirected edge is stored twice
const int MAXQ = 100010;     // max queries (M <= 100000) plus slack
const int MLOG = 20;         // binary-lifting levels; 2^19 > MAXV

namespace Bilibili {

// Adjacency list ("forward star"): head[u] is u's first edge id (-1 = none),
// nxt[e] the following edge id of the same source, to[e] the endpoint.
// (Named nxt rather than next to avoid clashing with std::next.)
int head[MAXV], val[MAXV], ecnt;
int to[MAXE], nxt[MAXE];
int n, m;

// stk holds nodes not yet assigned to a block during dfs_block;
// block[u] is the block id of node u, bcnt the number of blocks built.
int stk[MAXV], top;
int block[MAXV], bcnt, bsize;

struct Query {
    int u, v, id;
    // Reads the two endpoints from stdin and records the query index i.
    void read(int i) {
        id = i;
        scanf("%d%d", &u, &v);
    }
    // Orders the endpoints so block[u] <= block[v] (a path is symmetric).
    void adjust() {
        if (block[u] > block[v]) swap(u, v);
    }
    // Mo ordering: by block of u, ties broken by block of v.
    bool operator<(const Query &rhs) const {
        if (block[u] != block[rhs.u]) return block[u] < block[rhs.u];
        return block[v] < block[rhs.v];
    }
} ask[MAXQ];
int ans[MAXQ];

/// Graph
// Clears adjacency heads for nodes 1..n.
void init() {
    memset(head + 1, -1, n * sizeof(int));
    ecnt = 0;
}

// Adds the undirected edge (u, v) as two directed edges.
void add_edge(int u, int v) {
    to[ecnt] = v; nxt[ecnt] = head[u]; head[u] = ecnt++;
    to[ecnt] = u; nxt[ecnt] = head[v]; head[v] = ecnt++;
}

// Replaces a[1..n] by its 1-based rank among the distinct values, so the
// weights can be used directly as indices into cnt[].
void gethash(int a[], int n) {
    static int tmp[MAXV];
    int cnt = 0;
    for (int i = 1; i <= n; ++i) tmp[cnt++] = a[i];
    sort(tmp, tmp + cnt);
    cnt = unique(tmp, tmp + cnt) - tmp;
    for (int i = 1; i <= n; ++i)
        a[i] = lower_bound(tmp, tmp + cnt, a[i]) - tmp + 1;
}

// Reads N, M, the weights, the N-1 edges and the M queries from stdin.
void read() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; ++i) scanf("%d", &val[i]);
    gethash(val, n);
    init();
    for (int i = 1, u, v; i < n; ++i) {
        scanf("%d%d", &u, &v);
        add_edge(u, v);
    }
    for (int i = 0; i < m; ++i) ask[i].read(i);
}

/// find_block
// Pops the top `cnt` nodes of stk into a new block and resets cnt to 0.
void add_block(int &cnt) {
    while (cnt--) block[stk[--top]] = bcnt;
    bcnt++;
    cnt = 0;
}

// Assigns the leftover nodes (fewer than bsize, near the root) to the last
// block that was built.
void rest_block() {
    while (top) block[stk[--top]] = bcnt - 1;
}

// Post-order DFS that cuts the tree into blocks of at least bsize nodes.
// Returns how many nodes of u's subtree are still on stk (not yet assigned).
// Recursion depth equals the tree height (up to N on a path-shaped tree).
int dfs_block(int u, int f) {
    int size = 0;
    for (int p = head[u]; ~p; p = nxt[p]) {
        int v = to[p];
        if (v == f) continue;
        size += dfs_block(v, u);
        // Enough pending nodes from u's children: close a block now.
        if (size >= bsize) add_block(size);
    }
    stk[top++] = u;
    size++;
    if (size >= bsize) add_block(size);
    return size;
}

// Builds the blocks with block size about sqrt(n), rooted at node 1.
void init_block() {
    bsize = max(1, static_cast<int>(sqrt(n)));
    dfs_block(1, 0);
    rest_block();
}

/// ask_rmq
// fa[k][u] is the 2^k-th ancestor of u, or -1 above the root.
int fa[MLOG][MAXV];
int dep[MAXV];

// Records depth and direct parent of every node (root 1 has parent -1).
void dfs_lca(int u, int f, int depth) {
    dep[u] = depth;
    fa[0][u] = f;
    for (int p = head[u]; ~p; p = nxt[p]) {
        int v = to[p];
        if (v != f) dfs_lca(v, u, depth + 1);
    }
}

// Fills the binary-lifting table from the direct parents.
void init_lca() {
    dfs_lca(1, -1, 0);
    for (int k = 0; k + 1 < MLOG; ++k) {
        for (int u = 1; u <= n; ++u) {
            if (fa[k][u] == -1) fa[k + 1][u] = -1;
            else fa[k + 1][u] = fa[k][fa[k][u]];
        }
    }
}

// Lowest common ancestor of u and v via binary lifting.
int ask_lca(int u, int v) {
    if (dep[u] < dep[v]) swap(u, v);
    // Lift u to v's depth.
    for (int k = 0; k < MLOG; ++k) {
        if ((dep[u] - dep[v]) & (1 << k)) u = fa[k][u];
    }
    if (u == v) return u;
    for (int k = MLOG - 1; k >= 0; --k) {
        if (fa[k][u] != fa[k][v])
            u = fa[k][u], v = fa[k][v];
    }
    return fa[0][u];
}

/// modui
// vis[u]: node u is in the current set; cnt[w]: how many selected nodes have
// weight w; diff: number of weights with cnt > 0.
bool vis[MAXV];
int diff, cnt[MAXV];

// Toggles node u in/out of the current set, keeping diff up to date.
void xorNode(int u) {
    if (vis[u]) vis[u] = false, diff -= (--cnt[val[u]] == 0);
    else vis[u] = true, diff += (++cnt[val[u]] == 1);
}

// Toggles every node on the path u-v except their LCA.
void xorPathWithoutLca(int u, int v) {
    if (dep[u] < dep[v]) swap(u, v);
    while (dep[u] != dep[v])
        xorNode(u), u = fa[0][u];
    while (u != v)
        xorNode(u), u = fa[0][u],
        xorNode(v), v = fa[0][v];
}

// Moves the current set from path(u, v) to path(taru, tarv). With S(a, b) =
// path(a, b) minus its LCA, VFK's identity
// S(taru, tarv) = S(u, v) xor S(u, taru) xor S(v, tarv) is applied, then the
// old LCA is toggled out and the new LCA toggled in.
void moveNode(int u, int v, int taru, int tarv) {
    xorPathWithoutLca(u, taru);
    xorPathWithoutLca(v, tarv);
    xorNode(ask_lca(u, v));
    xorNode(ask_lca(taru, tarv));
}

// Sorts the queries in Mo order and walks between consecutive query paths,
// starting from the single-node path (1, 1).
void make_ans() {
    for (int i = 0; i < m; ++i) ask[i].adjust();
    sort(ask, ask + m);
    int nowu = 1, nowv = 1;
    xorNode(1);
    for (int i = 0; i < m; ++i) {
        moveNode(nowu, nowv, ask[i].u, ask[i].v);
        ans[ask[i].id] = diff;
        nowu = ask[i].u, nowv = ask[i].v;
    }
}

// Prints the answers in the original query order, one per line.
void print_ans() {
    for (int i = 0; i < m; ++i)
        printf("%d\n", ans[i]);
}

void solve() {
    read();
    init_block();
    init_lca();
    make_ans();
    print_ans();
}

}  // namespace Bilibili

int main() {
    Bilibili::solve();
}