启发式合并
启发式合并
先看看什么是启发式算法。
启发式算法可以这样定义:一个基于直观或经验构造的算法,在可接受的花费(指计算时间和空间)下给出待解决组合优化问题每一个实例的一个可行解,该可行解与最优解的偏离程度一般不能被预计。现阶段,启发式算法以仿自然体算法为主,主要有蚁群算法、模拟退火法、神经网络等。
\(from\) 百度百科
再来看看启发式合并
这个东西的原理很简单,就是你考虑合并 \(2\) 个数据结构,如果直接合并,复杂度去最坏情况,为 \(O(n^2)\)
但是,你只需要记录一下 \(size\) 然后每次跑的时候把 \(size\) 小的合并到 \(size\) 大的。
这个算法感觉和之前的暴力也没什么区别就吧。
但是,你会发现每个元素最多合并 \(log_m\) 次,\(n\) 个元素,最坏复杂度 \(O(n * log_n)\)
至于是为什么,不多赘述。因为不会。。
其实启发式合并和线段树合并一样,只是一种工具,一种优化。
不过,启发式合并可以适用于各种不同的数据结构,比如 \(set\),\(splay\) 等等。
伪代码就不放了,因为它适用面太广了,没有什么固定的模板。
例题:
题目大意
你有 \(n\) 个布丁,每个布丁都有它的颜色,一共有 \(m\) 次操作。
每次操作可以:
第一,把所有颜色为 \(x\) 的布丁变成颜色为 \(y\) 的布丁。
第二,问当前一共有多少段颜色。
题解
对于每个颜色维护一个数据结构,在进行操作 \(1\) 的时候,就把 \(x\) 和 \(y\) 所在数据结构进行启发式合并。
思路很简单,但是代码实现却需要处理很多细节。
代码
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <string>
#include <queue>
#define maxn 200020
#define ls x << 1
#define rs x << 1 | 1
#define inf 0x3f3f3f3f
#define inc(i) (++ (i))
#define dec(i) (-- (i))
#define mid ((l + r) >> 1)
// #define int long long
#define XRZ 1000000003
#define debug() puts("XRZ TXDY");
#define mem(i, x) memset(i, x, sizeof(i));
#define Next(i, u) for(register int i = head[u]; i ; i = e[i].nxt)
#define file(x) freopen(#x".in", "r", stdin), freopen(#x".out", "w", stdout);
#define Rep(i, a, b) for(register int i = (a) , i##Limit = (b) ; i <= i##Limit ; inc(i))
#define Dep(i, a, b) for(register int i = (a) , i##Limit = (b) ; i >= i##Limit ; dec(i))
int dx[10] = {1, -1, 0, 0};
int dy[10] = {0, 0, 1, -1};
using namespace std;
inline int read() {
register int x = 0, f = 1; register char c = getchar();
while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
while(c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar();
return x * f;
} int ans, a[maxn], head[maxn], nxt[maxn], num[maxn], S[maxn], fa[maxn];
void merge(int x, int y) {
for(int i = head[x]; i; i = nxt[i]) ans -= (a[i - 1] == y) + (a[i + 1] == y);
for(int i = head[x]; i; i = nxt[i]) a[i] = y;
nxt[S[x]] = head[y], head[y] = head[x], num[y] += num[x];
head[x] = S[x] = num[x] = 0;
}
signed main() { int n = read(), m = read();
Rep(i, 1, n) { a[i] = read();
fa[a[i]] = a[i], ans += a[i] != a[i - 1];
if(head[a[i]] == 0) S[a[i]] = i;
inc(num[a[i]]); nxt[i] = head[a[i]], head[a[i]] = i;
} Rep(i, 1, m) { int opt = read();
if(opt == 1) { int x = read(), y = read();
if(x == y) continue;
if(num[fa[x]] > num[fa[y]]) swap(fa[x], fa[y]);
if(num[fa[x]] == 0) continue;
merge(fa[x], fa[y]);
} else printf("%d\n", ans);
}
return 0;
}
题目描述
给你 \(n\) 个节点,取出每个节点都要付出相应的代价。
你可以一次取多个 不在 一条从 \(1\) 出发的链 的节点一起取出,代价为其中最大的。
题解
首先考虑部分分,
如果你只有链的情况。
考虑贪心,最大的只能和另一条链上最大的匹配。
不断配对就可以了。
再扩展到树上。
是不是对于每 \(2\) 条链做一次合并,最后的复杂度是 \(O(n ^ 2)\) 的
这个是不可以过的,所以我们需要考虑优化,既然是在启发式合并里面,那就显然是用启发式合并去优化这个暴力合并。
记录 \(size\) 把小的合并到大的即可,复杂度 \(O(n * log_n)\)
代码
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <vector>
#include <queue>
#define maxn 200001
#define ls x << 1
#define rs x << 1 | 1
#define inf 0x3f3f3f3f
#define inc(i) (++ (i))
#define dec(i) (-- (i))
#define mid ((l + r) >> 1)
#define int long long
#define XRZ 1000000003
#define debug() puts("XRZ TXDY");
#define mem(i, x) memset(i, x, sizeof(i));
#define Next(i, u) for(register int i = head[u]; i ; i = e[i].nxt)
#define file(x) freopen(#x".in", "r", stdin), freopen(#x".out", "w", stdout);
#define Rep(i, a, b) for(register int i = (a) , i##Limit = (b) ; i <= i##Limit ; inc(i))
#define Dep(i, a, b) for(register int i = (a) , i##Limit = (b) ; i >= i##Limit ; dec(i))
int dx[10] = {1, -1, 0, 0};
int dy[10] = {0, 0, 1, -1};
using namespace std;
inline int read() {
register int x = 0, f = 1; register char c = getchar();
while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
while(c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar();
return x * f;
} int Ans, a[maxn];//, head[maxn];
priority_queue<int> Q[maxn]; vector<int> s, qwq[maxn];
// struct node { int nxt, to;} e[maxn << 1];
// void add(int x, int y) { e[inc(cnt)] = (node) {head[x], y}; head[x] = cnt;}
void merge(int x, int y) {
if(Q[x].size() < Q[y].size()) swap(Q[x], Q[y]);
while(Q[y].size()) {
s.push_back(max(Q[x].top(), Q[y].top()));
Q[x].pop(), Q[y].pop();
} while(s.size()) Q[x].push(s.back()), s.pop_back();
}
void Dfs(int x) {
// Next(i, x) { int v = e[i].to; Dfs(v); merge(x, v);}
Rep(i, 0, qwq[x].size() - 1) Dfs(qwq[x][i]), merge(x, qwq[x][i]);
Q[x].push(a[x]);
}
signed main() { int n = read();
Rep(i, 1, n) a[i] = read();
Rep(i, 2, n) { int u = read(); qwq[u].push_back(i); }
// Rep(i, 2, n) { int u = read(); add(i, u);}
Dfs(1); while(Q[1].size()) Ans += Q[1].top(), Q[1].pop();
printf("%lld", Ans);
return 0;
}