启发式合并,DSU on Tree
一、启发式合并
1.1传统启发式合并
启发式合并是做的一个什么事情?
给你\(n\)个集合,令\(s_i = \lbrace i\rbrace\)
选两个集合\(x,y\),把\(y\)里面的元素全部丢到\(x\)里面,令\(s_x = s_x\cup s_y\),\(s_y = \emptyset\)
这样做\(n-1\)次之后,他们就合并成了一个集合了。
思考,如何去做?
想法1:
之前学过的并查集,好像有类似操作的感觉。但是我们把两个集合合并操作,是去令\(fa[f_x] = f_y\),这其实是一个打标记的操作,并没有真的把元素放进去。
想法2:暴力模拟
for(z:s_y)
{
把z放入s_x中;
清空s_y;
}
想法三:启发式合并:把两个集合元素合并,我们考虑把小的集合并到大的集合里是更优的。
复杂度?是\(O(n\log n)\)的
证明:
考虑每个元素的贡献。
小的\(|s_x|\)—>大的\(|s_y|\),其中\((|s_x|\leq |s_y|)\),合并之后的集合大小至少是\(2|s_x|\)
一个元素被操作一次,它的集合就变大为原来的\(2\)倍。因为我们最后的集合大小是\(O(n)\)的,那么至多合并\(O(\log n)\)次,总共有\(n\)个元素,那么时间复杂度是\(O(n\log n)\)的
例题:HNOI2009, 梦幻布丁
思路:我们最终只考虑段数,那么我们可以明确的是,把\(x\)变成\(y\),和把\(y\)变成\(x\)是一样的。但是写的时候要注意细节,比如要把\(1\)变成\(2\),但是我们考虑把\(2\)变成\(1\),那这样之后对\(2\)的操作就没有了,是不对的,所以我们还是要记录一下到底变成了什么。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 101000;
const int M = 1010000;
const int LOGN = 18;
int a[N],n,m,ans;
vector<int>pos[M];
int main()
{
cin>>n>>m;
for(int i = 1;i<=n;i++)
{
cin>>a[i];
pos[a[i]].push_back(i);
}
//把0和n+1也搞进来,就不用特判边界了。
for(int i = 1;i<=n+1;i++)
ans += a[i]!=a[i-1];
for(int i = 0;i<m;i++)
{
int op;
cin>>op;
if(op==1)
{
int x,y;
cin>>x>>y;
if(x==y)continue;
if(pos[x].size()>pos[y].size())
pos[x].swap(pos[y]);//O(1)
//注意这里我们是把两个vec给swap了,但是没办法把a里面的值也swap
if(pos[y].empty())continue;
auto modify = [&](int p,int col){
ans -= (a[p]!=a[p-1])+(a[p]!=a[p+1]);
a[p] = col;
ans += (a[p]!=a[p-1])+(a[p]!=a[p+1]);
};
int col = a[pos[y][0]];
for(auto p:pos[x])
{
modify(p,col);
pos[y].push_back(p);
}
pos[x].clear();
}
else{
cout<<ans-1<<endl;
}
}
return 0;
}
1.2启发式合并维护查询
例题:路径最小值
思路:按照边权从大到小排序,依次加入
每次把两个集并在一起,如果一个询问,一个在左边,一个在右边。我们加入当前这条边,他们两个连通了,说明,加入的这条边就是当前询问的答案。为什么呢?假设加入的这条边长度为\(k\),因为我们边权按照从大到小排序了的,之前加入的都比这个更长(\(\ge k\))。但是我们加入这条边,两个集合就连通了说明,长度刚好是\(k\)。
我们去维护两个集合,一个是点集,另一个是询问的集合。
我们\(for\)所有的询问,把一个询问\((u,v)\)放到\(u\)和\(v\)点对应的点集里。所以一个集合的询问集合就是包含了其中一个点的所有询问。
所以我们对一个询问,\(u\)看它的对应的另一个端点\(v\)在不在另一个点集里面
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 201000;
int n,q;
array<int,3>E[N];
array<int,2>Q[N];
set<int>vec[N];
map<int,int>que[N];
int ans[N],fa[N];
int find(int x)
{
if(x==fa[x])
return x;
return fa[x] = find(fa[x]);
}
int main()
{
ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
cin>>n>>q;
for(int i = 1;i<n;i++)
cin>>E[i][1]>>E[i][2]>>E[i][0];
//按照边权,从大到小排序
sort(E+1,E+n);
reverse(E+1,E+n);
for(int i = 1;i<=n;i++)
{
vec[i].insert(i);
fa[i] = i;
}
for(int i=1;i<=q;i++)
{
cin>>Q[i][0]>>Q[i][1];
que[Q[i][0]][i] = Q[i][1];
que[Q[i][1]][i] = Q[i][0];
}
for(int i = 1;i<n;i++)
{
int u = find(E[i][1]),v = find(E[i][2]);
//u->v
if(vec[u].size()>vec[v].size())
swap(v,u);
for(auto [id,w]:que[u])
{
if(vec[v].count(w))//问题解决
{
ans[id] = E[i][0];
que[v].erase(id);
}
else//否则把这个询问插入到v这个集合里面。
{
que[v][id] = w;
}
}
que[u].clear();
for(auto w:vec[u])vec[v].insert(w);
fa[u] = v;
}
for(int i = 1;i<=q;i++)
cout<<ans[i]<<"\n";
return 0;
}
1.3启发式分治
思路:我们先看对于整个区间\([1,n]\)来说,是不是存在一个点\(x\)只出现了一次。那么对于其他更小的包含\(x\)的区间肯定也是合法的。我们再递归分裂到左右两边去。
需要解决的问题:
-
如何判断\(x\)在\([l,r]\)只出现了一次?
方法:我们去记录\(pre[x]\)是不是\(<l\)并且\(nxt[x]\)是不是\(>r\)即可。
-
分治时间复杂度?很遗憾是\(O(n^2)\)的,我们寄了。那怎么办呢?
考虑用启发式分治。
对于靠边的元素:分治不均匀,我们希望更快找到它。
对于中间的元素:分下去是均匀的,那无所谓,时间复杂度是对滴。
时间复杂度\(T(n) = T(x)+T(n-x)+O(\min(x,n-x))\),其中\(T(n)\)表示解决\(n\)个元素的时间复杂度。
\(T(n) = O(n\log n)\)
例题:好序列
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 201000;
int a[N],pre[N],nxt[N],n;
bool solve(int l,int r)
{
if(l>r)return true;
for(int pl = l,pr = r;pl<=pr;pl++,pr--)
{
if(pre[pl]<l&&nxt[pl]>r)
{
return solve(l,pl-1)&&solve(pl+1,r);
}
if(pre[pr]<l&&nxt[pr]>r)
{
return solve(l,pr-1)&&solve(pr+1,r);
}
}
return false;
}
bool solve()
{
cin>>n;
for(int i = 1;i<=n;i++)
cin>>a[i];
map<int,int>pos;
for(int i = 1;i<=n;i++)
{
if(pos.count(a[i]))pre[i] = pos[a[i]];
else pre[i] = 0;
pos[a[i]] = i;
}
pos.clear();
for(int i = n;i>=1;i--)
{
if(pos.count(a[i]))nxt[i] = pos[a[i]];
else nxt[i] = n+1;
pos[a[i]] = i;
}
return solve(1,n);
}
int main()
{
ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
int t;
cin>>t;
for(int i = 1;i<=t;i++)
puts(solve()?"non-boring":"boring");
}
二、DSU on Tree
本质上是:启发式合并。树上启发式合并(dsu on tree)对于某些树上离线问题可以速度大于等于大部分算法且更易于理解和实现的算法。常用于解决“对于每个节点,询问关于其子树的某些信息”的问题。
对于每个点,我们找它最大的儿子(重儿子),其他儿子叫做轻儿子。
我们先把这个点并到重儿子在的集合里面,接下来把每个轻儿子都并到重儿子里面去。
想法\(1\):
\(u\)的集合,重儿子全部继承过来,再把轻儿子依次合并进去(for轻儿子所以元素)。
想法2:
(for轻儿子所有元素,改成for轻儿子子树里面所有节点)
分析:
- 时间复杂度\(O(n\log n)\)
空间复杂度\(O(n)\)
只需要维护一个集合
- 对于信息处理:加入、清空操作
核心代码:
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5+10;
int l[N],r[N],id[N],sz[N],hs[N],tot;
//预处理出dfs序(为了减小遍历子树的常数),每个点的重儿子
void dfs_init(int u,int fa)
{
l[u] = ++tot;
id[tot] = u;
sz[u] = 1;
hs[u] = -1;
for(auto v:e[u])
{
if(v==fa)continue;
dfs(v,u);
sz[u]+=sz[v];
if(hs[u]==-1||sz[v]>sz[ha[u]])ha[u] = v;
}
r[u] = tot;
}
void dfs_solve(int u,int fa,bool keep)//keep表示需不需要保留当前信息
{//dfs到轻儿子是不需要保留的,dfs到重儿子是需要保留的
for(auto v:e[u])if(v!=fa&&v!=hs[u]){
dfs_solve(v,u,false);
}
if(hs[u]!=-1)
dfs_solve(hs[u],u,true);//重儿子的集合
for(auto v:e[u])
{
if(v!=fa&&v!=hs[u])//v是轻儿子
{
//把v子树里面所有点加入到重儿子的集合里。
for(int x = l[v];x <= r[v];x++)
add(id[x]);
}
}
add(u);//把u本身加入
if(!keep)//不需要保留的话就清空
{
for(int x = l[u];x<=r[u];x++)
del[id[x]];
}
}
int main()
{
dfs_init(1,0);
dfs_solve(1,0,false);
return 0;
}
重要的思想:
对于以 u 为根的子树
①. 先统计它轻子树(轻儿子为根的子树)的答案,统计完后删除信息
②. 再统计它重子树(重儿子为根的子树)的答案 ,统计完后保留信息
③. 然后再将重子树的信息合并到 u上
④. 再去遍历 u 的轻子树,然后把轻子树的信息合并到 u 上
⑤. 判断 u 的信息是否需要传递给它的父节点(u 是否是它父节点的重儿子)
DSU on Tree模板
const int N = 1e5 + 10;
int n, k;
vector<int> e[N];
int l[N], r[N], id[N], sz[N], hs[N], tot;
inline void add(int u)//加入u对info的影响
{
}
inline void del(int u)//清除u对info的影响
{
}
inline void query(int k, int u)
{
}
void dfs_init(int u,int f) {
l[u] = ++tot;
id[tot] = u;
sz[u] = 1;
hs[u] = -1;
for (auto v : e[u]) {
if (v == f) continue;
dfs_init(v, u);
sz[u] += sz[v];
if (hs[u] == -1 || sz[v] > sz[hs[u]])
hs[u] = v;
}
r[u] = tot;
}
void dfs_solve(int u, int f, bool keep) {
for (auto v : e[u]) {
if (v != f && v != hs[u]) {
dfs_solve(v, u, false);
}
}
if (hs[u] != -1) {
dfs_solve(hs[u], u, true);
}
for (auto v : e[u]) {
if (v != f && v != hs[u]) {
for (int x = l[v]; x <= r[v]; x++)
query(id[x]);
for (int x = l[v]; x <= r[v]; x++)
add(id[x]);
}
}
//query(u);
add(u);
if (!keep) {
for(int x = l[u]; x <= r[u]; x++)
del(id[x]);
}
}
void solve()
{
cin>>n;
for (int i = 1; i < n; i++) {
int u, v; cin>>u>>v;
e[u].push_back(v);
e[v].push_back(u);
}
dfs_init(1, 0);
dfs_solve(1, 0, false);
}
2.1解决子树问题
例题:Lomsat gelral
题意:
- 有一棵 \(n\) 个结点的以 \(1\) 号结点为根的有根树。
- 每个结点都有一个颜色,颜色是以编号表示的, \(i\) 号结点的颜色编号为 \(c_i\)。
- 如果一种颜色在以 \(x\) 为根的子树内出现次数最多,称其在以 \(x\) 为根的子树中占主导地位。显然,同一子树中可能有多种颜色占主导地位。
- 你的任务是对于每一个 \(i\in[1,n]\),求出以 \(i\) 为根的子树中,占主导地位的颜色的编号和。
- \(n\le 10^5,c_i\le n\)
思路:\(DSU\) $ on $ \(Tree。\)
\(cnt[i]\)记录每个颜色出现次数。
我们遍历每一个节点\(u\),按照以下步骤遍历:
- 先遍历轻儿子,计算答案,不\(keep\)
- 遍历重儿子,\(keep\)它对\(cnt\)数组影响
- 再次遍历以\(u\)为节点的轻儿子的子树节点,加入这些节点的贡献,得到\(ans[u]\)
注意:除了重儿子,每次遍历完\(cnt\)要清空。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e5+10;
vector<int>e[N];
int l[N],r[N],id[N],sz[N],hs[N],tot,c[N],n;
int cnt[N];//每个颜色出现次数
int maxcnt;//众数出现次数
ll sumcnt;//众数的和
ll ans[N];
void dfs_init(int u,int fa)
{
l[u] = ++tot;
id[tot] = u;
sz[u] = 1;
hs[u] = -1;
for(auto v:e[u])
{
if(v==fa)continue;
dfs_init(v,u);
sz[u]+=sz[v];
if(hs[u]==-1||sz[v]>sz[hs[u]])hs[u] = v;
}
r[u] = tot;
}
void dfs_solve(int u,int fa,bool keep)
{
for(auto v:e[u])if(v!=fa&&v!=hs[u]){
dfs_solve(v,u,false);
}
if(hs[u]!=-1)
dfs_solve(hs[u],u,true);
auto add = [&](int x){
x = c[x];
cnt[x]++;
if(cnt[x]>maxcnt)
maxcnt = cnt[x],sumcnt = 0;
if(cnt[x]==maxcnt)
sumcnt += x;
};
auto del = [&](int x){
x = c[x];
cnt[x]--;
};
for(auto v:e[u])
{
if(v!=fa&&v!=hs[u])
{
for(int x = l[v];x <= r[v];x++)
add(id[x]);
}
}
add(u);
ans[u] = sumcnt;
if(!keep)
{
maxcnt = 0;
sumcnt = 0;
for(int x = l[u];x<=r[u];x++)
del(id[x]);
}
}
int main()
{
cin>>n;
for(int i = 1;i<=n;i++)
cin>>c[i];
for(int i = 1;i<n;i++)
{
int u,v;
cin>>u>>v;
e[u].push_back(v);
e[v].push_back(u);
}
dfs_init(1,0);
dfs_solve(1,0,false);
for(int i = 1;i<=n;i++)
cout<<ans[i]<<" ";
return 0;
}
例题2:Tree Requests
题意:给定一个以 \(1\) 为根的 \(n\) 个结点的树,每个点上有一个字母(a
-z
),每个点的深度定义为该节点到 \(1\) 号结点路径上的点数。每次询问 \(a, b\) 查询以 \(a\) 为根的子树内深度为 \(b\) 的结点上的字母重新排列之后是否能构成回文串。
思路:考虑能否构成回文,其中只和字母个数的奇偶性有关。
合法条件:1. 都是偶数 2.仅有一个是奇数
先离线存储以\(u\)为根的询问。维护一个\(cnt[d][c]\)数组,表示深度为\(d\),字母为\(c\)的个数。
对于询问的话,遍历26个字母,确定奇偶性去check即可。
// AC one more times
// nndbk
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod = 1e9 + 7;
const int N = 5e5 + 10;
int n, m;
vector<int> e[N];
int l[N], r[N], id[N], sz[N], hs[N], tot;
string s;
int c[N],cnt[N][30],dep[N],ans[N];
vector<array<int,2>>que[N];
inline void add(int u)
{
cnt[dep[u]][c[u]]++;
}
inline void del(int u)
{
cnt[dep[u]][c[u]]--;
}
inline void query(int u)
{
for(auto [h,idx] : que[u])
{
int sum = 0,odd = 0;
for(int i = 0;i < 26; i++)
{
sum += cnt[h][i];
if(cnt[h][i] & 1)odd++;
}
if((sum&1)&&odd==1)ans[idx] = true;
else if(!(sum&1)&&!odd)ans[idx] = true;
else ans[idx] = false;
}
}
void dfs_init(int u,int f) {
l[u] = ++tot;
id[tot] = u;
sz[u] = 1;
hs[u] = -1;
dep[u] = dep[f] + 1;
for (auto v : e[u]) {
if (v == f) continue;
dfs_init(v, u);
sz[u] += sz[v];
if (hs[u] == -1 || sz[v] > sz[hs[u]])
hs[u] = v;
}
r[u] = tot;
}
void dfs_solve(int u, int f, bool keep) {
for (auto v : e[u]) {
if (v != f && v != hs[u]) {
dfs_solve(v, u, false);
}
}
if (hs[u] != -1) {
dfs_solve(hs[u], u, true);
}
for (auto v : e[u]) {
if (v != f && v != hs[u]) {
// for (int x = l[v]; x <= r[v]; x++)
// query(id[x]);
for (int x = l[v]; x <= r[v]; x++)
add(id[x]);
}
}
add(u);
query(u);
if (!keep) {
for(int x = l[u]; x <= r[u]; x++)
del(id[x]);
}
}
void solve()
{
cin>>n>>m;
for (int v = 2; v <= n; v++) {
int u; cin>>u;
e[u].push_back(v);
e[v].push_back(u);
}
cin>>s;
s = "?"+s;
for(int i = 1;i <= n; i++)
c[i] = s[i]-'a';
for(int i = 1;i <= m; i++)
{
int u,h; cin>>u>>h;//以u为根的子树内深度为h的节点上的字母重排后能否构成回文
que[u].push_back({h,i});
}
dfs_init(1, 0);
dfs_solve(1, 0, false);
for(int i = 1;i <= m; i++)
{
if(ans[i])cout<<"Yes\n";
else cout<<"No\n";
}
}
int main()
{
ios::sync_with_stdio(false); cin.tie(nullptr), cout.tie(nullptr);
solve();
return 0;
}
vp遇见的题目:F. Strange Memory(二进制拆位)
// AC one more times
// nndbk
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod = 1e9 + 7;
const int N = 1e5 + 10;
int n, k;
vector<int> e[N];
int l[N], r[N], id[N], sz[N], hs[N], tot;
int w[N],mp[1048576 + 10][22][2];
ll res;
inline void add(int k,int u)
{
for(int i = 0;i < 22; i++)
mp[k][i][(u>>i)&1]++;
}
inline void del(int k,int u)
{
for(int i = 0;i < 22; i++)
mp[k][i][(u>>i)&1] = 0;
}
inline void query(int k, int u)
{
for(int i = 0;i < 22; i++)
if((u>>i)&1)
res += (1ll<<i)*mp[k][i][0];
else
res += (1ll<<i)*mp[k][i][1];
}
void dfs_init(int u,int f) {
l[u] = ++tot;
id[tot] = u;
sz[u] = 1;
hs[u] = -1;
for (auto v : e[u]) {
if (v == f) continue;
dfs_init(v, u);
sz[u] += sz[v];
if (hs[u] == -1 || sz[v] > sz[hs[u]])
hs[u] = v;
}
r[u] = tot;
}
void dfs_solve(int u, int f, bool keep) {
for (auto v : e[u]) {
if (v != f && v != hs[u]) {
dfs_solve(v, u, false);
}
}
if (hs[u] != -1) {
dfs_solve(hs[u], u, true);
}
for (auto v : e[u]) {
if (v != f && v != hs[u]) {
for (int x = l[v]; x <= r[v]; x++)
query(w[id[x]]^w[u],id[x]);
for (int x = l[v]; x <= r[v]; x++)
add(w[id[x]],id[x]);
}
}
//query(u);
add(w[u],u);
if (!keep) {
for(int x = l[u]; x <= r[u]; x++)
del(w[id[x]],id[x]);
}
}
void solve()
{
cin>>n;
for(int i = 1;i <= n; i++)
cin>>w[i];
for (int i = 1; i < n; i++) {
int u, v; cin>>u>>v;
e[u].push_back(v);
e[v].push_back(u);
}
dfs_init(1, 0);
dfs_solve(1, 0, false);
//cout<<sizeof(mp)/1024/1024<<'\n';
cout<<res<<"\n";
}
int main()
{
ios::sync_with_stdio(false); cin.tie(nullptr), cout.tie(nullptr);
solve();
return 0;
}
2.2解决路径问题
例题:IOI2011, Race
题意:给一棵n个点的树,每条边有权。求一条简单路径,权值和等于k,且边的数量最小。
思路:由于:\(dis[u,v] = dep[u]+dep[v]-2*dep[LCA(u,v)]\)。
若边权和为\(k\),即想要找到\(dis[u,v] = k\)
那么我们对于一个节点\(w\),看它的子树里面的节点\(v\),看是否存在一个节点\(v\)是深度\(dep[v] = k+2*dep[LCA(u,v)]-dep[u]\),如果存在的话,我们就找到了一条边权和为\(k\)的路径。
题目要去边数最小,具体的边数也有等于\(dep'[u] + dep'[v]-2dep'[LCA(u,v)]\)。其中\(dep'\)相当于距离根节点的边数。我们要记一下\(dep\)等于某个值的最小边数。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 2e5+10;
vector<pair<int,int>>e[N];
int l[N],r[N],id[N],sz[N],hs[N],tot,c[N],n,k;
int ans;
int dep1[N];//边数的深度
ll dep2[N];//权值和的深度
map<ll,int>val;
void dfs_init(int u,int fa)
{
l[u] = ++tot;
id[tot] = u;
sz[u] = 1;
hs[u] = -1;
for(auto [v,w]:e[u])
{
if(v==fa)continue;
dep1[v] = dep1[u] + 1;
dep2[v] = dep2[u] + w;
dfs_init(v,u);
sz[u]+=sz[v];
if(hs[u]==-1||sz[v]>sz[hs[u]])hs[u] = v;
}
r[u] = tot;
}
void dfs_solve(int u,int fa,bool keep)
{
for(auto [v,w]:e[u])
if(v!=fa&&v!=hs[u])
dfs_solve(v,u,false);
if(hs[u]!=-1)
dfs_solve(hs[u],u,true);
auto query = [&](int w){
ll d2 = k + 2*dep2[u] - dep2[w];
if(val.count(d2))
ans = min(ans,val[d2]+dep1[w]-2*dep1[u]);
};
auto add = [&](int w){
if(val.count(dep2[w]))
val[dep2[w]] = min(val[dep2[w]],dep1[w]);
else
val[dep2[w]] = dep1[w];
};
for(auto [v,w]:e[u])
{
if(v!=fa&&v!=hs[u])
{
for(int x = l[v];x <= r[v];x++)
query(id[x]);
for(int x = l[v];x <= r[v];x++)
add(id[x]);
}
}
query(u),add(u);
if(!keep)
val.clear();
}
int main()
{
std::ios::sync_with_stdio(false); cin.tie(nullptr), cout.tie(nullptr);
cin>>n>>k;
for(int i = 1;i<n;i++)
{
int u,v,w;;
cin>>u>>v>>w;
u++,v++;
e[u].push_back({v,w});
e[v].push_back({u,w});
}
ans = n+1;
dfs_init(1,0);
dfs_solve(1,0,false);
if(ans>=n+1)ans = -1;
cout<<ans<<"\n";
return 0;
}