CSU 1811: Tree Intersection(线段树启发式合并||map启发式合并)
http://acm.csu.edu.cn/csuoj/problemset/problem?pid=1811
题意:给出一棵树,每一个结点有一个颜色,然后依次删除树边,问每次删除树边之后,分开的两个连通块里面的颜色交集数是多少,即公有的颜色数。
思路:可以像树形DP一样,先处理出儿子结点,然后回溯的时候和儿子结点的子树合并,更新父结点的子树,然后更新父结点的答案。
枚举删除的边。以u为子树里面放的是某个颜色有多少,然后和一开始统计的某种颜色的总数比较,如果树里面这个颜色的数目小于这个颜色的总数,那么这个颜色就肯定有一些是在另一个连通块里面,那么就是公有的,对答案有贡献。
如果树里面这个颜色的数目为0或者等于这个颜色的总数,说明不是公有的,那么对答案就没贡献。
因为直接合并的话O(n^2)的复杂度太大,因此用启发式合并达到O(nlogn)。
写了两种方法,都很好理解。
线段树:空间不足,因此要动态开辟结点
1 #include <bits/stdc++.h> 2 using namespace std; 3 #define N 100010 4 struct node { 5 int val, cnt, l, r; // val是颜色c的个数,sum是答案的个数 6 } tree[N*50]; 7 struct Edge { 8 int v, nxt, id; 9 } edge[N*2]; 10 int n, col[N], sum[N], head[N], tot, sz, root[N], ans[N], e[N]; 11 // 共有的就是交集就是答案 12 void Add(int u, int v, int id) { 13 edge[tot] = (Edge) { v, head[u], id }; head[u] = tot++; 14 edge[tot] = (Edge) { u, head[v], id }; head[v] = tot++; 15 } 16 17 void PushUp(int now) { 18 tree[now].cnt = tree[tree[now].l].cnt + tree[tree[now].r].cnt; 19 } 20 21 int Build(int l, int r, int c) { 22 int now = ++sz; 23 tree[now].l = tree[now].r = 0; 24 int m = (l + r) >> 1; 25 if(l == r) { 26 tree[now].val = 1; 27 tree[now].cnt = (tree[now].val < sum[c] ? 1 : 0); 28 return now; 29 } 30 if(c <= m) tree[now].l = Build(l, m, c); 31 else tree[now].r = Build(m + 1, r, c); 32 PushUp(now); 33 return now; 34 } 35 36 void Merge(int &rt1, int rt2, int l, int r) { 37 if(!rt1 || !rt2) { 38 if(!rt1) rt1 = rt2; // 小的变成大的 39 return ; 40 } 41 if(l == r) { 42 tree[rt1].val += tree[rt2].val; 43 tree[rt1].cnt = (tree[rt1].val < sum[l] ? 1 : 0); 44 return ; 45 } 46 int m = (l + r) >> 1; 47 Merge(tree[rt1].l, tree[rt2].l, l, m); 48 Merge(tree[rt1].r, tree[rt2].r, m + 1, r); 49 PushUp(rt1); 50 } 51 52 void DFS(int u, int fa, int id) { 53 root[u] = Build(1, n, col[u]); 54 for(int i = head[u]; ~i; i = edge[i].nxt) { 55 int v = edge[i].v; 56 if(v == fa) continue; 57 DFS(v, u, edge[i].id); 58 Merge(root[u], root[v], 1, n); 59 } 60 if(id) e[id] = tree[root[u]].cnt; 61 } 62 63 int main() { 64 while(~scanf("%d", &n)) { 65 memset(sum, 0, sizeof(sum)); 66 for(int i = 1; i <= n; i++) scanf("%d", &col[i]), sum[col[i]]++; 67 memset(head, -1, sizeof(head)); sz = 0, tot = 0; 68 for(int i = 1; i < n; i++) { 69 int u, v; scanf("%d%d", &u, &v); 70 Add(u, v, i); 71 } 72 DFS(1, -1, 0); 73 for(int i = 1; i < n; i++) 74 printf("%d\n", e[i]); 75 } 76 return 0; 77 }
map:
1 #include <bits/stdc++.h> 2 using namespace std; 3 #define N 100010 4 struct Edge { 5 int v, nxt, id; 6 } edge[N*2]; 7 map<int, int> num[N]; 8 int n, col[N], sum[N], cnt[N], e[N], head[N], tot; 9 10 void Add(int u, int v, int id) { 11 edge[tot] = (Edge) { v, head[u], id }; head[u] = tot++; 12 edge[tot] = (Edge) { u, head[v], id }; head[v] = tot++; 13 } 14 15 void DFS(int u, int fa, int id) { 16 num[u][col[u]] = 1; cnt[u] = num[u][col[u]] < sum[col[u]] ? 1 : 0; 17 for(int i = head[u]; ~i; i = edge[i].nxt) { 18 int v = edge[i].v, idd = edge[i].id; 19 if(v == fa) continue; 20 DFS(v, u, idd); 21 if(num[u].size() < num[v].size()) // 启发式合并 22 swap(num[u], num[v]), swap(cnt[u], cnt[v]); 23 for(map<int,int>::iterator it = num[v].begin(); it != num[v].end(); it++) { 24 int key = it->first, cc = it->second; 25 if(num[u][key] + cc < sum[key] && num[u][key] == 0) cnt[u]++; // 如果之前没被算过,并且是共有的就要加上 26 if(num[u][key] + cc == sum[key] && num[u][key] > 0) cnt[u]--; // 如果之前被算过,并且是特有的就要减去 27 num[u][key] += cc; 28 } 29 } 30 if(id) e[id] = cnt[u]; 31 } 32 33 int main() { 34 while(~scanf("%d", &n)) { 35 memset(sum, 0, sizeof(sum)); 36 memset(cnt, 0, sizeof(cnt)); 37 for(int i = 1; i <= n; i++) scanf("%d", &col[i]), sum[col[i]]++, num[i].clear(); 38 memset(head, -1, sizeof(head)); tot = 0; 39 for(int i = 1; i < n; i++) { 40 int u, v; scanf("%d%d", &u, &v); 41 Add(u, v, i); 42 } 43 DFS(1, -1, 0); 44 for(int i = 1; i < n; i++) printf("%d\n", e[i]); 45 } 46 return 0; 47 }