树上启发式合并学习笔记
参考链接:https://www.cnblogs.com/zzqsblog/p/6146916.html
树上启发式合并(dsu on tree)可以在 $O(n\log n)$ 的时间内解决很多关于子树的无修改询问,比如求每棵子树中出现次数最多的颜色等。
树上启发式合并的算法流程:
1. 先 dfs 一次,求出每个节点的子树大小,并记录每个节点的重儿子(与树链剖分的预处理差不多)。
2. 第二次 dfs 用来回答关于子树的询问,分为 4 步:
(1) 先递归处理所有非重儿子(轻儿子),即求出各轻儿子子树的询问答案,求解后把统计信息清零。
(2) 再递归处理重儿子,即求出重儿子子树的询问答案,求解后保留统计信息,不清零。
(3) 用另一个 dfs 函数把当前节点本身以及所有轻儿子子树的信息加入统计(只统计信息,不再求答案)。此时统计的恰好是当前节点整棵子树的信息,便得到了这个节点的询问答案。
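(4) 根据传进来的参数决定信息是否清零:若当前节点是其父亲的重儿子,则保留整棵子树的信息供父亲继续使用;否则把整棵子树的信息清零。
复杂度:一个节点只会在它到根路径上的每条轻边处被额外统计、清零一次,而任意节点到根的路径上至多有 $O(\log n)$ 条轻边,因此总时间复杂度为 $O(n\log n)$。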
代码:
CF 600E:
#include <bits/stdc++.h>
using namespace std;

// CF 600E: for every vertex, print the sum of all colours that occur the
// maximum number of times inside its subtree (dsu on tree).

const int maxn = 100010;

// Edge-list adjacency: head[u] is the first edge id of u, Next[e] the
// following edge, ver[e] the edge target. Each undirected edge is stored twice.
int head[maxn], Next[maxn * 2], ver[maxn * 2], tot;
int cnt[maxn];    // cnt[c]: occurrences of colour c among the counted vertices
int col[maxn];    // colour of each vertex (1-indexed)
int sz[maxn];     // subtree size
int son[maxn];    // heavy child, 0 when the vertex has no child
bool skip[maxn];  // set on a heavy child whose subtree is already counted
long long ans[maxn];
long long sum;    // sum of the colours whose count equals mx
int mx;           // current maximum occurrence count

// Append the directed edge from -> to.
void add(int from, int to) {
    ver[++tot] = to;
    Next[tot] = head[from];
    head[from] = tot;
}

// First pass: subtree sizes and heavy child of every vertex.
void get_son(int u, int parent = 0) {
    sz[u] = 1;
    for (int e = head[u]; e != 0; e = Next[e]) {
        const int v = ver[e];
        if (v == parent) continue;
        get_son(v, u);
        sz[u] += sz[v];
        if (sz[v] > sz[son[u]]) son[u] = v;
    }
}

// Add (delta = 1) or remove (delta = -1) every vertex of u's subtree,
// not descending into children flagged in skip[]. On insertion, mx/sum are
// kept equal to the maximum count and the sum of colours reaching it.
void update(int u, int parent, int delta) {
    int &c = cnt[col[u]];
    c += delta;
    if (delta > 0) {
        if (c > mx) {
            mx = c;
            sum = col[u];
        } else if (c == mx) {
            sum += col[u];
        }
    }
    for (int e = head[u]; e != 0; e = Next[e]) {
        const int v = ver[e];
        if (v == parent || skip[v]) continue;
        update(v, u, delta);
    }
}

// Second pass. kep != 0 means u is a heavy child: its counts stay in place
// for the parent; otherwise they are removed before returning.
void dfs(int u, int parent = 0, int kep = 0) {
    const int heavy = son[u];
    for (int e = head[u]; e != 0; e = Next[e]) {
        const int v = ver[e];
        if (v == parent || v == heavy) continue;
        dfs(v, u);                 // light child: answered, then cleared
    }
    if (heavy) {
        dfs(heavy, u, 1);          // heavy child: answered, counts kept
        skip[heavy] = true;        // do not count it a second time below
    }
    update(u, parent, 1);          // add u and its light subtrees
    ans[u] = sum;
    if (heavy) skip[heavy] = false;
    if (!kep) {
        update(u, parent, -1);     // whole subtree, heavy child included
        mx = 0;
        sum = 0;
    }
}

int main() {
    int n;
    scanf("%d", &n);
    for (int v = 1; v <= n; ++v) scanf("%d", &col[v]);
    for (int k = 1; k < n; ++k) {
        int u, v;
        scanf("%d%d", &u, &v);
        add(u, v);
        add(v, u);
    }
    get_son(1);
    dfs(1);
    for (int v = 1; v <= n; ++v) printf("%lld ", ans[v]);
}
CF 741D:
#include <bits/stdc++.h>
using namespace std;

// CF 741D: dp[v] is the longest path (counted in edges) inside v's subtree
// whose edge letters XOR to a mask with at most one set bit.

const int maxn = 500010;
int head[maxn], Next[maxn * 2], ver[maxn * 2], edge[maxn * 2];
int cnt[maxn], sz[maxn], son[maxn], tot;
const int S = 'v' - 'a' + 1, INF = 1e9;  // S letters -> S-bit masks
int lca_deep;  // depth of the vertex whose subtree is being merged
int skip;      // heavy child excluded from update<...> traversals (0 = none)
int now;       // best path length found so far for the current merge
int dp[maxn];
int re[1 << S];   // re[mask]: max depth of a counted vertex with a[v] == mask, or -INF
int a[maxn];      // XOR of the edge-letter bits on the root -> v path
int deep[maxn];   // depth, root = 0

// Append the directed edge from -> to carrying letter bit `mask`.
void add(int from, int to, int mask) {
    ver[++tot] = to;
    edge[tot] = mask;
    Next[tot] = head[from];
    head[from] = tot;
}

// First pass: sizes, heavy children, depths and root-path masks.
void get_son(int u, int parent = 0) {
    sz[u] = 1;
    for (int e = head[u]; e; e = Next[e]) {
        const int v = ver[e];
        if (v == parent) continue;
        a[v] = a[u] ^ edge[e];
        deep[v] = deep[u] + 1;
        get_son(v, u);
        sz[u] += sz[v];
        if (sz[v] > sz[son[u]]) son[u] = v;
    }
}

// Forget the depth recorded for u's mask.
void clr(int u) { re[a[u]] = -INF; }

// Pair u with every counted vertex whose mask differs from a[u] in at most
// one bit; such a path through depth lca_deep has length
// deep[u] + re[...] - 2 * lca_deep.
void change(int u) {
    const int extra = deep[u] - lca_deep * 2;
    now = max(now, re[a[u]] + extra);
    for (int bit = 0; bit < S; ++bit)
        now = max(now, re[a[u] ^ (1 << bit)] + extra);
}

// Record u as a counted vertex.
void maintain(int u) {
    if (deep[u] > re[a[u]]) re[a[u]] = deep[u];
}

// Apply Visit to every vertex of u's subtree, not entering `skip`.
template <void (*Visit)(int)>
void update(int u, int parent) {
    Visit(u);
    for (int e = head[u]; e; e = Next[e]) {
        const int v = ver[e];
        if (v == parent || v == skip) continue;
        update<Visit>(v, u);
    }
}

// Second pass. kep != 0 means u is a heavy child whose re[] entries stay
// in place for its parent; otherwise they are reset before returning.
void dfs(int u, int parent = 0, int kep = 0) {
    const int heavy = son[u];
    for (int e = head[u]; e; e = Next[e]) {
        const int v = ver[e];
        if (v != parent && v != heavy) dfs(v, u);
    }
    if (heavy) {
        dfs(heavy, u, 1);
        skip = heavy;
    }
    lca_deep = deep[u];
    for (int e = head[u]; e; e = Next[e]) {
        const int v = ver[e];
        if (v == parent) continue;
        dp[u] = max(dp[u], dp[v]);        // best path entirely inside a child
        if (v == heavy) continue;
        // Query before inserting, so vertices of the same light subtree are
        // never paired with each other (their LCA is not u).
        update<change>(v, u);
        update<maintain>(v, u);
    }
    change(u);                             // paths that start at u itself
    maintain(u);
    dp[u] = max(dp[u], now);
    skip = 0;
    if (!kep) {
        update<clr>(u, parent);
        now = -INF;
    }
}

int main() {
    int n;
    scanf("%d", &n);
    fill(re, re + (1 << S), -INF);
    for (int v = 2; v <= n; ++v) {
        int p;
        char letter[3];
        scanf("%d%s", &p, letter);
        const int mask = 1 << (letter[0] - 'a');
        add(v, p, mask);
        add(p, v, mask);
    }
    get_son(1);
    dfs(1);
    for (int v = 1; v <= n; ++v) printf("%d ", dp[v]);
}
CF 1009F:
#include <bits/stdc++.h>
using namespace std;

// CF 1009F: ans[v] is the smallest distance j (from v, inside v's subtree)
// at which the number of subtree vertices is maximal.

const int INF = 0x3f3f3f3f;
const int maxn = 1000010;
int head[maxn], Next[maxn * 2], ver[maxn * 2];
int cnt[maxn];    // cnt[d]: counted vertices at absolute depth d
int col[maxn], sz[maxn], son[maxn], deep[maxn], tot;
int skip;         // heavy child excluded from update() (0 = none)
int ans[maxn];
int mx;           // largest cnt[d] among counted vertices
int tmp;          // smallest depth d with cnt[d] == mx

// Append the directed edge from -> to.
void add(int from, int to) {
    ver[++tot] = to;
    Next[tot] = head[from];
    head[from] = tot;
}

// First pass: depths, subtree sizes and heavy children.
void get_son(int u, int parent = 0) {
    sz[u] = 1;
    for (int e = head[u]; e; e = Next[e]) {
        const int v = ver[e];
        if (v == parent) continue;
        deep[v] = deep[u] + 1;
        get_son(v, u);
        sz[u] += sz[v];
        if (sz[v] > sz[son[u]]) son[u] = v;
    }
}

// Add (delta = 1) or remove (delta = -1) u's subtree, not entering `skip`.
// Insertions keep mx / tmp up to date (ties resolved to the smaller depth).
void update(int u, int parent, int delta) {
    const int d = deep[u];
    cnt[d] += delta;
    if (delta > 0) {
        if (cnt[d] > mx) {
            mx = cnt[d];
            tmp = d;
        } else if (cnt[d] == mx && d < tmp) {
            tmp = d;
        }
    }
    for (int e = head[u]; e; e = Next[e]) {
        const int v = ver[e];
        if (v == parent || v == skip) continue;
        update(v, u, delta);
    }
}

// Second pass. kep != 0 keeps u's counts for its parent (heavy child).
void dfs(int u, int parent = 0, int kep = 0) {
    const int heavy = son[u];
    for (int e = head[u]; e; e = Next[e]) {
        const int v = ver[e];
        if (v != parent && v != heavy) dfs(v, u);
    }
    if (heavy) {
        dfs(heavy, u, 1);
        skip = heavy;
    }
    update(u, parent, 1);
    ans[u] = tmp - deep[u];   // absolute depth -> distance from u
    skip = 0;
    if (!kep) {
        update(u, parent, -1);
        mx = 0;
        tmp = INF;
    }
}

int main() {
    int n;
    scanf("%d", &n);
    for (int k = 1; k < n; ++k) {
        int u, v;
        scanf("%d%d", &u, &v);
        add(u, v);
        add(v, u);
    }
    get_son(1);
    dfs(1);
    for (int v = 1; v <= n; ++v) printf("%d ", ans[v]);
}
CF 570D:
#include <bits/stdc++.h>
using namespace std;

// CF 570D: for each query (v, h), decide whether the letters of the vertices
// at depth h inside v's subtree can be rearranged into a palindrome, i.e.
// whether at most one letter occurs an odd number of times.

const int maxn = 500010;
int head[maxn], Next[maxn * 2], ver[maxn * 2];
int sz[maxn], son[maxn], deep[maxn], tot;
int skip;                 // heavy child excluded from update() (0 = none)
int ans[maxn];            // 1 = "Yes", 0 = "No", indexed by query number
int a[maxn];              // letter of each vertex as 0..25
vector<int> q[maxn];      // q[v]: depths asked about v
vector<int> id[maxn];     // id[v][i]: query number of q[v][i]
bitset<26> re[maxn];      // re[d]: per-letter parity at absolute depth d
// Letter string (1-indexed). Kept in static storage: as a local in main()
// this 500 KB buffer would stay on the stack underneath the get_son()/dfs()
// recursion, which can be n levels deep on a path-shaped tree.
char s[maxn];

// Append the directed edge from -> to.
void add(int from, int to) {
    ver[++tot] = to;
    Next[tot] = head[from];
    head[from] = tot;
}

// First pass: depths, subtree sizes and heavy children.
void get_son(int x, int f = 0) {
    sz[x] = 1;
    for (int i = head[x]; i; i = Next[i]) {
        int y = ver[i];
        if (y == f) continue;
        deep[y] = deep[x] + 1;
        get_son(y, x);
        sz[x] += sz[y];
        if (sz[y] > sz[son[x]]) son[x] = y;
    }
}

// Toggle the parity of each vertex's letter at its depth for x's subtree,
// not entering `skip`. Because it is an XOR, calling it twice undoes it.
void update(int x, int f) {
    re[deep[x]].flip(a[x]);
    for (int i = head[x]; i; i = Next[i]) {
        int y = ver[i];
        if (y != f && y != skip) update(y, x);
    }
}

// Second pass. kep != 0 keeps x's parities for its parent (heavy child);
// otherwise they are toggled back before returning.
void dfs(int x, int f = 0, int kep = 0) {
    for (int i = head[x]; i; i = Next[i]) {
        int y = ver[i];
        if (y != f && y != son[x]) dfs(y, x);
    }
    if (son[x]) {
        dfs(son[x], x, 1);
        skip = son[x];
    }
    update(x, f);
    for (size_t i = 0; i < q[x].size(); i++) {
        const int h = q[x][i];
        // A depth outside re[] holds no vertex: the empty multiset is a palindrome.
        const bool empty_level = h < 0 || h >= maxn;
        ans[id[x][i]] = (empty_level || re[h].count() <= 1) ? 1 : 0;
    }
    skip = 0;
    if (!kep) update(x, f);
}

int main() {
    int n, m;
    scanf("%d%d", &n, &m);
    for (int i = 2; i <= n; i++) {
        int p;
        scanf("%d", &p);
        add(p, i);
        add(i, p);
    }
    scanf("%s", s + 1);
    for (int i = 1; i <= n; i++) a[i] = s[i] - 'a';
    for (int i = 1; i <= m; i++) {
        int v, h;
        scanf("%d%d", &v, &h);
        q[v].push_back(h);
        id[v].push_back(i);
    }
    deep[1] = 1;
    get_son(1);
    dfs(1);
    for (int i = 1; i <= m; i++) puts(ans[i] == 1 ? "Yes" : "No");
}