SPOJ COT2 Count on a tree II(树上莫队)

题目链接:http://www.spoj.com/problems/COT2/

You are given a tree with N nodes. The tree nodes are numbered from 1 to N. Each node has an integer weight.

We will ask you to perform the following operation:

  • u v: ask how many distinct integers occur as node weights on the path from u to v.

 

Input

The first line contains two integers N and M (N <= 40000, M <= 100000).

The second line contains N integers. The i-th integer is the weight of the i-th node.

Each of the next N-1 lines contains two integers u v, describing an edge (u, v).

Each of the next M lines contains two integers u v, asking how many distinct integers occur as node weights on the path from u to v.

Output

For each operation, print its result.

 

题目大意:给一棵树,每个点有一个权值。多次询问路径(a, b)上有多少个权值不同的点。

思路:参考VFK WC 2013 糖果公园 park 题解(此题比COT2要难。)

http://vfleaking.blog.163.com/blog/static/174807634201311011201627/

 

代码(2.37S):

  1 #include <bits/stdc++.h>
  2 using namespace std;
  3 
  4 const int MAXV = 40010;
  5 const int MAXE = MAXV << 1;
  6 const int MAXQ = 100010;
  7 const int MLOG = 20;
  8 
  9 namespace Bilibili {
 10 
 11 int head[MAXV], val[MAXV], ecnt;
 12 int to[MAXE], next[MAXE];
 13 int n, m;
 14 
 15 int stk[MAXV], top;
 16 int block[MAXV], bcnt, bsize;
 17 
 18 struct Query {
 19     int u, v, id;
 20     void read(int i) {
 21         id = i;
 22         scanf("%d%d", &u, &v);
 23     }
 24     void adjust() {
 25         if(block[u] > block[v]) swap(u, v);
 26     }
 27     bool operator < (const Query &rhs) const {
 28         if(block[u] != block[rhs.u]) return block[u] < block[rhs.u];
 29         return block[v] < block[rhs.v];
 30     }
 31 } ask[MAXQ];
 32 int ans[MAXQ];
 33 /// Graph
 34 void init() {
 35     memset(head + 1, -1, n * sizeof(int));
 36     ecnt = 0;
 37 }
 38 
 39 void add_edge(int u, int v) {
 40     to[ecnt] = v; next[ecnt] = head[u]; head[u] = ecnt++;
 41     to[ecnt] = u; next[ecnt] = head[v]; head[v] = ecnt++;
 42 }
 43 
 44 void gethash(int a[], int n) {
 45     static int tmp[MAXV];
 46     int cnt = 0;
 47     for(int i = 1; i <= n; ++i) tmp[cnt++] = a[i];
 48     sort(tmp, tmp + cnt);
 49     cnt = unique(tmp, tmp + cnt) - tmp;
 50     for(int i = 1; i <= n; ++i)
 51         a[i] = lower_bound(tmp, tmp + cnt, a[i]) - tmp + 1;
 52 }
 53 
 54 void read() {
 55     scanf("%d%d", &n, &m);
 56     for(int i = 1; i <= n; ++i) scanf("%d", &val[i]);
 57     gethash(val, n);
 58     init();
 59     for(int i = 1, u, v; i < n; ++i) {
 60         scanf("%d%d", &u, &v);
 61         add_edge(u, v);
 62     }
 63     for(int i = 0; i < m; ++i) ask[i].read(i);
 64 }
 65 /// find_block
 66 void add_block(int &cnt) {
 67     while(cnt--) block[stk[--top]] = bcnt;
 68     bcnt++;
 69     cnt = 0;
 70 }
 71 
 72 void rest_block() {
 73     while(top) block[stk[--top]] = bcnt - 1;
 74 }
 75 
 76 int dfs_block(int u, int f) {
 77     int size = 0;
 78     for(int p = head[u]; ~p; p = next[p]) {
 79         int v = to[p];
 80         if(v == f) continue;
 81         size += dfs_block(v, u);
 82         if(size >= bsize) add_block(size);
 83     }
 84     stk[top++] = u;
 85     size++;
 86     if(size >= bsize) add_block(size);
 87     return size;
 88 }
 89 
 90 void init_block() {
 91     bsize = max(1, (int)sqrt(n));
 92     dfs_block(1, 0);
 93     rest_block();
 94 }
 95 /// ask_rmq
 96 int fa[MLOG][MAXV];
 97 int dep[MAXV];
 98 
 99 void dfs_lca(int u, int f, int depth) {
100     dep[u] = depth;
101     fa[0][u] = f;
102     for(int p = head[u]; ~p; p = next[p]) {
103         int v = to[p];
104         if(v != f) dfs_lca(v, u, depth + 1);
105     }
106 }
107 
108 void init_lca() {
109     dfs_lca(1, -1, 0);
110     for(int k = 0; k + 1 < MLOG; ++k) {
111         for(int u = 1; u <= n; ++u) {
112             if(fa[k][u] == -1) fa[k + 1][u] = -1;
113             else fa[k + 1][u] = fa[k][fa[k][u]];
114         }
115     }
116 }
117 
118 int ask_lca(int u, int v) {
119     if(dep[u] < dep[v]) swap(u, v);
120     for(int k = 0; k < MLOG; ++k) {
121         if((dep[u] - dep[v]) & (1 << k)) u = fa[k][u];
122     }
123     if(u == v) return u;
124     for(int k = MLOG - 1; k >= 0; --k) {
125         if(fa[k][u] != fa[k][v])
126             u = fa[k][u], v = fa[k][v];
127     }
128     return fa[0][u];
129 }
130 /// modui
131 bool vis[MAXV];
132 int diff, cnt[MAXV];
133 
134 void xorNode(int u) {
135     if(vis[u]) vis[u] = false, diff -= (--cnt[val[u]] == 0);
136     else vis[u] = true, diff += (++cnt[val[u]] == 1);
137 }
138 
139 void xorPathWithoutLca(int u, int v) {
140     if(dep[u] < dep[v]) swap(u, v);
141     while(dep[u] != dep[v])
142         xorNode(u), u = fa[0][u];
143     while(u != v)
144         xorNode(u), u = fa[0][u],
145         xorNode(v), v = fa[0][v];
146 }
147 
148 void moveNode(int u, int v, int taru, int tarv) {
149     xorPathWithoutLca(u, taru);
150     xorPathWithoutLca(v, tarv);
151     //printf("debug %d %d\n", ask_lca(u, v), ask_lca(taru, tarv));
152     xorNode(ask_lca(u, v));
153     xorNode(ask_lca(taru, tarv));
154 }
155 
156 void make_ans() {
157     for(int i = 0; i < m; ++i) ask[i].adjust();
158     sort(ask, ask + m);
159     int nowu = 1, nowv = 1; xorNode(1);
160     for(int i = 0; i < m; ++i) {
161         moveNode(nowu, nowv, ask[i].u, ask[i].v);
162         ans[ask[i].id] = diff;
163         nowu = ask[i].u, nowv = ask[i].v;
164     }
165 }
166 
167 void print_ans() {
168     for(int i = 0; i < m; ++i)
169         printf("%d\n", ans[i]);
170 }
171 
172 void solve() {
173     read();
174     init_block();
175     init_lca();
176     make_ans();
177     print_ans();
178 }
179 
180 };
181 
182 int main() {
183     Bilibili::solve();
184 }
View Code

 

posted @ 2015-02-01 17:36  Oyking  阅读(1974)  评论(0编辑  收藏  举报