Loading

CSU 1811: Tree Intersection(线段树启发式合并||map启发式合并)

http://acm.csu.edu.cn/csuoj/problemset/problem?pid=1811

题意:给出一棵树,每一个结点有一个颜色,然后依次删除树边,问每次删除树边之后,分开的两个连通块里面的颜色交集数是多少,即公有的颜色数。

思路:可以像树形DP一样,先处理出儿子结点,然后回溯的时候和儿子结点的子树合并,更新父结点的子树,然后更新父结点的答案。

枚举删除的边。以u为子树里面放的是某个颜色有多少,然后和一开始统计的某种颜色的总数比较,如果树里面这个颜色的数目小于这个颜色的总数,那么这个颜色就肯定有一些是在另一个连通块里面,那么就是公有的,对答案有贡献。

如果树里面这个颜色的数目为0或者等于这个颜色的总数,说明不是公有的,那么对答案就没贡献。

因为直接合并的话O(n^2)的复杂度太大,因此用启发式合并达到O(nlogn)。

写了两种方法,都很好理解。

线段树:空间不足,因此要动态开辟结点

 1 #include <bits/stdc++.h>
 2 using namespace std;
 3 #define N 100010
 4 struct node {
 5     int val, cnt, l, r; // val是颜色c的个数,sum是答案的个数
 6 } tree[N*50];
 7 struct Edge {
 8     int v, nxt, id;
 9 } edge[N*2];
10 int n, col[N], sum[N], head[N], tot, sz, root[N], ans[N], e[N];
11 // 共有的就是交集就是答案
12 void Add(int u, int v, int id) {
13     edge[tot] = (Edge) { v, head[u], id }; head[u] = tot++;
14     edge[tot] = (Edge) { u, head[v], id }; head[v] = tot++;
15 }
16 
17 void PushUp(int now) {
18     tree[now].cnt = tree[tree[now].l].cnt + tree[tree[now].r].cnt;
19 }
20 
21 int Build(int l, int r, int c) {
22     int now = ++sz;
23     tree[now].l = tree[now].r = 0;
24     int m = (l + r) >> 1;
25     if(l == r) {
26         tree[now].val = 1;
27         tree[now].cnt = (tree[now].val < sum[c] ? 1 : 0);
28         return now;
29     }
30     if(c <= m) tree[now].l = Build(l, m, c);
31     else tree[now].r = Build(m + 1, r, c);
32     PushUp(now);
33     return now;
34 }
35 
36 void Merge(int &rt1, int rt2, int l, int r) {
37     if(!rt1 || !rt2) {
38         if(!rt1) rt1 = rt2; // 小的变成大的
39         return ;
40     }
41     if(l == r) {
42         tree[rt1].val += tree[rt2].val;
43         tree[rt1].cnt = (tree[rt1].val < sum[l] ? 1 : 0);
44         return ;
45     }
46     int m = (l + r) >> 1;
47     Merge(tree[rt1].l, tree[rt2].l, l, m);
48     Merge(tree[rt1].r, tree[rt2].r, m + 1, r);
49     PushUp(rt1);
50 }
51 
52 void DFS(int u, int fa, int id) {
53     root[u] = Build(1, n, col[u]);
54     for(int i = head[u]; ~i; i = edge[i].nxt) {
55         int v = edge[i].v;
56         if(v == fa) continue;
57         DFS(v, u, edge[i].id);
58         Merge(root[u], root[v], 1, n);
59     }
60     if(id) e[id] = tree[root[u]].cnt;
61 }
62 
63 int main() {
64     while(~scanf("%d", &n)) {
65         memset(sum, 0, sizeof(sum));
66         for(int i = 1; i <= n; i++) scanf("%d", &col[i]), sum[col[i]]++;
67         memset(head, -1, sizeof(head)); sz = 0, tot = 0;
68         for(int i = 1; i < n; i++) {
69             int u, v; scanf("%d%d", &u, &v);
70             Add(u, v, i);
71         }
72         DFS(1, -1, 0);
73         for(int i = 1; i < n; i++)
74             printf("%d\n", e[i]);
75     }
76     return 0;
77 }

 

map:

 1 #include <bits/stdc++.h>
 2 using namespace std;
 3 #define N 100010
 4 struct Edge {
 5     int v, nxt, id;
 6 } edge[N*2];
 7 map<int, int> num[N];
 8 int n, col[N], sum[N], cnt[N], e[N], head[N], tot;
 9 
10 void Add(int u, int v, int id) {
11     edge[tot] = (Edge) { v, head[u], id }; head[u] = tot++;
12     edge[tot] = (Edge) { u, head[v], id }; head[v] = tot++;
13 }
14 
15 void DFS(int u, int fa, int id) {
16     num[u][col[u]] = 1; cnt[u] = num[u][col[u]] < sum[col[u]] ? 1 : 0;
17     for(int i = head[u]; ~i; i = edge[i].nxt) {
18         int v = edge[i].v, idd = edge[i].id;
19         if(v == fa) continue;
20         DFS(v, u, idd);
21         if(num[u].size() < num[v].size()) // 启发式合并
22             swap(num[u], num[v]), swap(cnt[u], cnt[v]);
23         for(map<int,int>::iterator it = num[v].begin(); it != num[v].end(); it++) {
24             int key = it->first, cc = it->second;
25             if(num[u][key] + cc < sum[key] && num[u][key] == 0) cnt[u]++; // 如果之前没被算过,并且是共有的就要加上
26             if(num[u][key] + cc == sum[key] && num[u][key] > 0) cnt[u]--; // 如果之前被算过,并且是特有的就要减去
27             num[u][key] += cc;
28         }
29     }
30     if(id) e[id] = cnt[u];
31 }
32 
33 int main() {
34     while(~scanf("%d", &n)) {
35         memset(sum, 0, sizeof(sum));
36         memset(cnt, 0, sizeof(cnt));
37         for(int i = 1; i <= n; i++) scanf("%d", &col[i]), sum[col[i]]++, num[i].clear();
38         memset(head, -1, sizeof(head)); tot = 0;
39         for(int i = 1; i < n; i++) {
40             int u, v; scanf("%d%d", &u, &v);
41             Add(u, v, i);
42         }
43         DFS(1, -1, 0);
44         for(int i = 1; i < n; i++) printf("%d\n", e[i]);
45     }
46     return 0;
47 }

 

posted @ 2017-05-03 14:43  Shadowdsp  阅读(530)  评论(0编辑  收藏  举报