浅谈DSU on tree
writen by yzh on 2022/20/9
前言
fwyzh发现自己居然从来没写过dsu on tree的题。某天在nflsoj上还是败给了dsu on tree。便有此文。(时间不够写的比较水)
对dsu on tree 理解较浅
引入
DSU——Disjoint Set Union被oier情切的称为启发式合并。
启发式顾名思义是基于人类的经验及直观感觉,对一些算法进行的优化。所以为什么我没有这些直观感觉qwq
最常见的启发式合并就是并查集的按秩合并。显然一个人的直觉会认为将小的集合合并到大的集合中复杂度更优秀。所以按秩合并搭配记忆化便能达到\(O(n)\)优秀的复杂度
算法思想
可解决问题
对于一类树上问题,形如可通过枚举某个节点产生的贡献来求解的问题。
对于这类问题显然可以对一颗树进行dfs然后通过一些数据结构进行维护。但这样做复杂度往往会成为\(O(n^2log_2n)\)然而正解往往需要\(O(nlog_2^2n)\)甚至是\(O(nlog_2n)\)的复杂度才能通过。
对于这类问题学过重链剖分并且思维较活跃的佬可能能想到这类dsu on tree的做法。
树上启发式合并就是基于重链剖分的思想,对上述dfs进行优化。不懂重链剖分的同学可以先学一波重剖(但其实也没必要)。
重剖有这么一条性质:对于树上任意一个点,其到根最多经过不超过log n条轻链和重链
基于这条性质,我们可以大胆猜测dsu on tree的做法。
流程
对于每一个节点,我们暴力将轻儿子的贡献合并到根节点。
对于重儿子单独进行dfs,然后将重儿子的信息合并到根节点上。
看似十分暴力的复杂度,实际能达到优雅的\(O(nlog_2n)\)
证明可以感性理解一下,利用重剖性质,一个点最多被合并\(log_2 n\)次。
这就是树上启发式合并的大体思路。
例题1:lexiyvv's tree
题面
lexiyvv的家长要求他在一棵点数为n(n<=2e5)的树上花至少k小时跑步,且他跑步经过的路程必须是树上的一条路径。树上的第i条边连接点和,他通过这条边所花的时间固定为小时,求他跑步所花的最小时间。
思路
这道题暴力思路很显然,dfs暴力枚举每个根节点u,将u的儿子v所在的子树中所有的点到u的距离暴力放进set里,对于每一个子节点u所在子树的每一个子节点中可用lower_bound求出set中点v'使得与当前子节点的距离大于等于k。
对于上述dfs,显然我们可以用DSU的思想优化。
暴力枚举每个轻儿子所在子树中的答案。然后单独处理重儿子的答案。答案取最小即可。
代码&思路の细节
重剖部分(显然):
void dfs(int u,int fa,int dept){
siz[u]=1;
L[u]=++cnt;
id[cnt]=u;
dep[u]=dept;
int maxn=0;
for(int i=head[u];i;i=edge[i].nt){
int v=edge[i].v;
if(v==fa) continue;
dfs(v,u,dept+edge[i].w);
siz[u]+=siz[v];
if(siz[v]>maxn){
maxn=siz[v];
son[u]=v;
}
}
R[u]=cnt;
}
dsu和暴力加点部分:
set<int> s;
void add_tree(int u){
for(int i=L[u];i<=R[u];++i){
s.insert(dep[id[i]]);
}
}
int ans=1e18;
void dsu(int u,int fa){
// cout<<u<<endl;
s.insert(1e18);
for(int i=head[u];i;i=edge[i].nt){
int v=edge[i].v;
if(v==fa||v==son[u]) continue;
dsu(v,u);
s.clear();
}
if(son[u]) dsu(son[u],u);
s.insert(1e18);
for(int i=head[u];i;i=edge[i].nt){
int v=edge[i].v;
if(v==fa||v==son[u]) continue;
for(int i=L[v];i<=R[v];++i){
int num=*s.lower_bound(k+2*dep[u]-dep[id[i]]);
if(num!=1e18){
ans=min(ans,num+dep[id[i]]-2*dep[u]);
}
}
add_tree(v);
}
s.insert(dep[u]);
int num=*s.lower_bound(k+dep[u]);
if(num!=1e18){
ans=min(ans,num-dep[u]);
}
}
完整代码:
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=5e5+7;
int n,k,head[N],tot;
int son[N],L[N],R[N],cnt,siz[N],dep[N],id[N];
struct node
{
/* data */
int v,w,nt;
}edge[N*2];
void add(int u,int v,int w){
edge[++tot]={v,w,head[u]};head[u]=tot;
}
void dfs(int u,int fa,int dept){
siz[u]=1;
L[u]=++cnt;
id[cnt]=u;
dep[u]=dept;
int maxn=0;
for(int i=head[u];i;i=edge[i].nt){
int v=edge[i].v;
if(v==fa) continue;
dfs(v,u,dept+edge[i].w);
siz[u]+=siz[v];
if(siz[v]>maxn){
maxn=siz[v];
son[u]=v;
}
}
R[u]=cnt;
}
set<int> s;
void add_tree(int u){
for(int i=L[u];i<=R[u];++i){
s.insert(dep[id[i]]);
}
}
int ans=1e18;
void dsu(int u,int fa){
// cout<<u<<endl;
s.insert(1e18);
for(int i=head[u];i;i=edge[i].nt){
int v=edge[i].v;
if(v==fa||v==son[u]) continue;
dsu(v,u);
s.clear();
}
if(son[u]) dsu(son[u],u);
s.insert(1e18);
for(int i=head[u];i;i=edge[i].nt){
int v=edge[i].v;
if(v==fa||v==son[u]) continue;
for(int i=L[v];i<=R[v];++i){
int num=*s.lower_bound(k+2*dep[u]-dep[id[i]]);
if(num!=1e18){
ans=min(ans,num+dep[id[i]]-2*dep[u]);
}
}
add_tree(v);
}
s.insert(dep[u]);
int num=*s.lower_bound(k+dep[u]);
if(num!=1e18){
ans=min(ans,num-dep[u]);
}
}
signed main(){
freopen("tree.in","r",stdin);
freopen("tree.out","w",stdout);
scanf("%lld%lld",&n,&k);
for(int i=1;i<n;++i){
int u,v,w;
scanf("%lld%lld%lld",&u,&v,&w);
add(u,v,w);
add(v,u,w);
}
dfs(1,0,0);
// cout<<";;"<<endl;
// for(int i=1;i<=n;++i) cout<<
dsu(1,0);
if(ans==1e18){
puts("-1");
}
else printf("%lld",ans);
return 0;
}
例题2:树上统计
题面
给一个\(n(n<=1e5)\)个点的树,记\(f(l,r)\quad (l<=r)\)为\(l\)到\(r\)所有点连通所需最少边数。求
思路
考虑每条边对答案的贡献
以1号点为根节点,对于每条边(u,v),我们记v为更深的点,v的子树中所有点标记为1,其余点标记为0。那么这条边(u,v)对答案做出的贡献数,就是(0,1)点对数。
这里的(0,1)点对是指节点序号相邻的两个数为一个点对
在一个树中可以用(0,1)点对数=总点对数-(0,0)点对数-(1,1)点对数
的方法来求出(0,1)点对数。
对于维护(0,0)和(1,1)点对数显然可以用线段树来维护。
暴力的做法就是dfs枚举每条边,然后用线段树来维护。这样的复杂度为\(O(n^2log_2n)\)
显然用上文中讲到的树上启发式合并可以直接优化到\(O(nlog_2^2n)\)
代码&思路の细节
首先对这棵树进行重剖:
void dfs(int now) {
L[now] = ++T; // DSU用的DFS序
P[T] = now;
sz[now] = 1;
for (int i = head[now]; i; i = edge[i].nx) {
int nxt = edge[i].to;
if (nxt == fa[now])
continue;
fa[nxt] = now;
dfs(nxt);
sz[now] += sz[nxt];
if (sz[son[now]] < sz[nxt])
son[now] = nxt;
}
R[now] = T;
}
然后进行dfs操作,对与每个轻儿子暴力建树求解,对于重儿子单独dfs进行求解(欲查看线段树部分请看完整代码):
void DSU(int now) { // DSU板子
for (int i = head[now]; i; i = edge[i].nx) {
int nxt = edge[i].to;
if (nxt == fa[now] || nxt == son[now])
continue;
DSU(nxt), Addtree(nxt, 0);
}
if (son[now])
DSU(son[now]);
for (int i = head[now]; i; i = edge[i].nx) {
int nxt = edge[i].to;
if (nxt == fa[now] || nxt == son[now])
continue;
Addtree(nxt, 1);
}
AddPoint(now, 1);
if (now != 1)
ans += 1ll * n * (n + 1) / 2 - S.Query(); //这里把每条边压在了离根远的那个节点上,故根结点不用算答案
}
以下为完整代码:
#include <cstdio>
#define M 100005
struct E {
int to, nx;
} edge[M << 1];
int tot, head[M];
void Addedge(int a, int b) {
edge[++tot].to = b;
edge[tot].nx = head[a];
head[a] = tot;
}
void Read(int &x) { //快速读入
char ch = getchar();
x = 0;
bool f = 0;
while (ch < '0' || ch > '9') {
if (ch == '-')
f = 1;
ch = getchar();
}
while ('0' <= ch && ch <= '9') x = (x << 1) + (x << 3) + ch - '0', ch = getchar();
if (f)
x = -x;
}
struct Segment {
struct node {
int L, R, ls[2], rs[2]; // ls---->lsum rs---->rsum
long long sum; // L,R都在这个区间内的答案
int len() { return R - L + 1; }
} tree[M << 2];
int pos[M];
void Up(int p) {
int lson = p << 1, rson = p << 1 | 1;
tree[p].sum = tree[lson].sum + tree[rson].sum; //先收集答案
tree[p].ls[0] = tree[lson].ls[0], tree[p].ls[1] = tree[lson].ls[1];
tree[p].rs[0] = tree[rson].rs[0], tree[p].rs[1] = tree[rson].rs[1];
if (tree[p].ls[0] == tree[lson].len())
tree[p].ls[0] += tree[rson].ls[0];
if (tree[p].rs[0] == tree[rson].len())
tree[p].rs[0] += tree[lson].rs[0];
if (tree[p].ls[1] == tree[lson].len())
tree[p].ls[1] += tree[rson].ls[1];
if (tree[p].rs[1] == tree[rson].len())
tree[p].rs[1] += tree[lson].rs[1];
tree[p].sum += 1ll * tree[lson].rs[0] * tree[rson].ls[0] + 1ll * tree[lson].rs[1] * tree[rson].ls[1];
//中轴左右串合并更新答案
}
void Build(int L, int R, int p) {
tree[p].L = L, tree[p].R = R;
if (L == R) {
pos[L] = p; //存下线段树叶子节点的编号,为UPdata速度更快
tree[p].ls[0] = tree[p].rs[0] = 1; //一开始序列全为0
tree[p].ls[1] = tree[p].rs[1] = 0;
tree[p].sum = 1; //初始值
return;
}
int mid = (L + R) >> 1;
Build(L, mid, p << 1);
Build(mid + 1, R, p << 1 | 1);
Up(p);
}
void Updata(int x, int d) { //若看不懂可以写正常版本的updata
int p = pos[x];
tree[p].ls[0] = tree[p].rs[0] = !d; //更线段树新叶子节点
tree[p].ls[1] = tree[p].rs[1] = d;
tree[p].sum = 1;
p >>= 1;
while (p) Up(p), p >>= 1; //一路更新直到线段树根结点
}
long long Query() { return tree[1].sum; }
} S;
int n;
int fa[M], son[M], sz[M];
int L[M], R[M], P[M], T;
void dfs(int now) {
L[now] = ++T; // DSU用的DFS序
P[T] = now;
sz[now] = 1;
for (int i = head[now]; i; i = edge[i].nx) {
int nxt = edge[i].to;
if (nxt == fa[now])
continue;
fa[nxt] = now;
dfs(nxt);
sz[now] += sz[nxt];
if (sz[son[now]] < sz[nxt])
son[now] = nxt;
}
R[now] = T;
}
void AddPoint(int x, int d) {
S.Updata(x, d); //在线段树上更新答案
}
void Addtree(int x, int d) {
for (int i = L[x]; i <= R[x]; i++) AddPoint(P[i], d);
}
long long ans = 0;
void DSU(int now) { // DSU板子
for (int i = head[now]; i; i = edge[i].nx) {
int nxt = edge[i].to;
if (nxt == fa[now] || nxt == son[now])
continue;
DSU(nxt), Addtree(nxt, 0);
}
if (son[now])
DSU(son[now]);
for (int i = head[now]; i; i = edge[i].nx) {
int nxt = edge[i].to;
if (nxt == fa[now] || nxt == son[now])
continue;
Addtree(nxt, 1);
}
AddPoint(now, 1);
if (now != 1)
ans += 1ll * n * (n + 1) / 2 - S.Query(); //这里把每条边压在了离根远的那个节点上,故根结点不用算答案
}
void Solve() {
dfs(1);
S.Build(1, n, 1);
DSU(1);
printf("%lld\n", ans);
}
int main() {
freopen("treecnt.in", "r", stdin);
freopen("treecnt.out", "w", stdout);
Read(n);
for (int i = 1; i < n; i++) {
int a, b;
Read(a), Read(b);
Addedge(a, b);
Addedge(b, a);
}
Solve();
return 0;
}
题外话
很多佬发现,这道题还有个复杂度更优秀的做法——线段树合并\(O(nlog_2n)\)。
由于线段树合并与本文章无关,这里只放某位佬的代码,想了解的自行学习
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int N = 1e6 + 5;
ll C(int n) { return 1ll * n * (n + 1) / 2; }
struct segTree {
int cnt = 0, rt[N * 4];
struct node {
int L, R, lc, rc, lp, rp;
ll sum;
} tr[N * 4];
int newnode(int siz = 1) {
tr[++cnt].sum = C(siz);
tr[cnt].lp = tr[cnt].rp = siz;
return cnt;
}
void pushup(int now, int l, int r) {
int mid = (l + r) >> 1;
int len = r - l + 1, ls = tr[now].L, rs = tr[now].R;
if (!ls) {
ls = tr[ls].L = newnode(mid - l + 1);
}
if (!rs) {
rs = tr[rs].R = newnode(r - mid);
}
tr[now].lc = tr[ls].lc;
tr[now].rc = tr[rs].rc;
tr[now].lp = tr[ls].lp + (tr[ls].lp == mid - l + 1) * (tr[ls].rc == tr[rs].lc) * tr[rs].lp;
tr[now].rp = tr[rs].rp + (tr[rs].rp == r - mid) * (tr[ls].rc == tr[rs].lc) * tr[ls].rp;
tr[now].sum = tr[ls].sum + tr[rs].sum +
(tr[ls].rc == tr[rs].lc) * (-C(tr[ls].rp) - C(tr[rs].lp) + C(tr[ls].rp + tr[rs].lp));
}
void update(int &now, int l, int r, int k) {
if (!now)
now = newnode(r - l + 1);
if (l == r) {
tr[now].lc ^= 1;
tr[now].rc ^= 1;
return;
}
int mid = (l + r) >> 1;
if (k <= mid)
update(tr[now].L, l, mid, k);
if (mid + 1 <= k)
update(tr[now].R, mid + 1, r, k);
pushup(now, l, r);
}
int merge(int rta, int rtb, int l, int r) {
// cout << l << ' ' << r << endl;
if (!rta || !rtb)
return rta + rtb;
if (l == r) {
tr[rta].lc = tr[rta].rc = tr[rta].lc | tr[rtb].lc;
tr[rta].lp = tr[rta].rp = 1;
tr[rta].sum = 1;
return rta;
}
int mid = (l + r) >> 1;
tr[rta].L = merge(tr[rta].L, tr[rtb].L, l, mid);
tr[rta].R = merge(tr[rta].R, tr[rtb].R, mid + 1, r);
pushup(rta, l, r);
// cout <<"pushup"<< l << ' ' <<r << ' ' << tr[rta].sum << endl;
// int len = r - l + 1 , ls = tr[rta].L , rs = tr[rta].R;
// cout << tr[ls].rc << ' ' << tr[rs].lc << endl;
// cout << tr[ls].lc << endl;
// cout << tr[ls].sum + tr[rs].sum -C(tr[ls].rp) -C(tr[rs].lp) + C(tr[ls].rp + tr[rs].lp) <<
//endl;
return rta;
}
} s;
int n;
vector<int> a[N];
ll ans = 0;
void dfs(int u, int fa) {
s.rt[u] = s.newnode();
for (auto v : a[u]) {
if (v == fa)
continue;
dfs(v, u);
// printf("merge%d %d\n" , u , v);
s.rt[u] = s.merge(s.rt[u], s.rt[v], 1, n);
}
s.update(s.rt[u], 1, n, u);
ans = ans + C(n) - s.tr[s.rt[u]].sum;
}
int main() {
freopen("treecnt.in", "r", stdin);
freopen("treecnt.out", "w", stdout);
scanf("%d", &n);
for (int i = 1; i <= n - 1; i++) {
int x, y;
scanf("%d%d", &x, &y);
a[x].push_back(y);
a[y].push_back(x);
}
dfs(1, 0);
cout << ans << endl;
return 0;
}