Dsu On Tree

DSU On Tree是啥

中文名:树上启发式合并/静态链分治

先考虑启发式合并是啥 :

说到启发式合并,那么常见的就是并查集了。

我们将小的集合合并到大的集合中,就可以变成\(O(n\log n)\) 神奇

让高度小的树成为高度较大的树的子树,这个优化可以称为启发式合并算法。

然后就\(O(n^2)\rightarrow O(n\log n)\)

然后我们就可以看出来了,这玩意其实和莫队一样,就是一个优化后的暴力

一般用来求解子树上无修改的问题

一般的流程

  1. 先想出暴力怎么跑,然后验证一下对不对
  2. 先处理轻儿子及其子树的贡献
  3. 再处理重儿子及其子树的贡献
  4. 处理该节点的贡献
  5. 删去轻儿子的贡献

一般的板子是这样的

int fa[N],siz[N],son[N];
void dfs1(int x){//树链剖分的第一个dfs,处理出重儿子
    siz[x] = 1;
    for(int i = head[x]; i;i = edge[i].next){
        int y = edge[i].to;
        if(y == fa[x]) continue;
        fa[y] = x;dfs1(y);
        siz[x] += siz[y];
        if(siz[son[x]] < siz[y]) son[x] = y;
    }
}
int hugeson,ans[N],res;//当前节点的重儿子,答案数组
inline void add(int x){}//看题
inline void del(int x){}//看题
inline void change(int x,int val){//处理贡献
    add/del //看题
    for(int i = head[x]; i;i = edge[i].next){
        int y = edge[i].to;
        if(y == fa[x] || y == hugeson) continue;
        change(y,val);
    }
}
void dfs(int x,bool keep){//keep 记录是否为重儿子
    for(int i = head[x]; i;i = edge[i].next){
        int y = edge[i].to;
        if(y == fa[x] || y == son[x]) continue;
        dfs(y,0);//先处理轻儿子
    }
    if(son[x]) dfs(son[x],1),hugeson = son[x];//处理重儿子
    change(x,1);ans[x] = res;hugeson = 0;//处理当前节点,记录答案
    if(!keep) change(x,-1);//将轻儿子的贡献处理出来
}

时间复杂度证明

我们像树链剖分一样定义重边和轻边(连向重儿子的为重边,其余为轻边)。对于一棵有\(n\)个节点的树:

根节点到树上任意节点的轻边数不超过\(\log n\)条。我们设根到该节点有\(x\)条轻边该节点的子树大小为\(y\),显然轻边连接的子节点的子树大小小于父亲的一半(若大于一半就不是轻边了),则\(y<\frac{n}{2^x}\),显然\(n>2^x\),所以 \(x<\log n\)

又因为如果一个节点是其父亲的重儿子,则它的子树必定在它的兄弟之中最多,所以任意节点到根的路径上所有重边连接的父节点在计算答案时必定不会遍历到这个节点,所以一个节点的被遍历的次数等于它到根节点路径上的轻边数\(+1\)(之所以要 \(+1\)是因为它本身要被遍历到),所以一个节点的被遍历次数\(=\log n+1\), 总时间复杂度则为\(O(n(\log n+1))=O(n\log n)\),输出答案花费\(O(m)\)

贺的oi-wiki

DSU On Tree 应用

先看一道例题树上数颜色

这道题做法很多,我先用了一个暴力合并set水过去了(常数大,空间大,但好像卡不了还好想?)。

点此查看代码
#include<bits/stdc++.h>
#include<bits/extc++.h>
// using namespace __gnu_pbds;
// using namespace __gnu_cxx;
using namespace std;
#define infile(x) freopen(x,"r",stdin)
#define outfile(x) freopen(x,"w",stdout)
#define errfile(x) freopen(x,"w",stderr)
#ifdef LOCAL
    FILE *InFile = infile("in.in"),*OutFile = outfile("out.out");
    // FILE *ErrFile=errfile("err.err");
#else
    FILE *Infile = stdin,*OutFile = stdout;
    //FILE *ErrFile = stderr;
#endif
using ll=long long;using ull=unsigned long long;
using db = double;using ldb = long double;
const int N = 1e5 + 10;
struct EDGE{int to,next;}edge[N<<1];
int head[N],cnt;
inline void add(int u,int v){
    edge[++cnt] = {v,head[u]};
    head[u] = cnt;
}
int a[N],n,ans[N];
set<int> have[N];
inline void Merge(int x,int y){
    for(auto i : have[y]) have[x].insert(i);
    set<int> ().swap(have[y]);
}
void dfs(int x,int fa){
    for(int i = head[x]; i;i = edge[i].next){
        int y = edge[i].to;
        if(y == fa) continue;
        dfs(y,x);
        Merge(x,y);
    }
    have[x].insert(a[x]);
    ans[x] = have[x].size();
}
inline void solve(){
    cin>>n;
    for(int i = 1,u,v;i < n; ++i){
        cin>>u>>v;
        add(u,v);add(v,u);
    }
    for(int i = 1;i <= n; ++i) cin>>a[i];
    dfs(1,0);
    int q;cin>>q;
    while(q--){
        int x;cin>>x;
        cout<<ans[x]<<'\n';
    }
}
signed main(){
    cin.tie(nullptr)->sync_with_stdio(false);
    cout.tie(nullptr)->sync_with_stdio(false);
    solve();
}

先不考虑这些乱写的做法,让我们考虑DSU On Tree的流程

  1. 考虑暴力 : 枚举子树,记录一个桶,若新出现,计数器加一
  2. 考虑优化 : 先统计轻儿子的答案,再统计重儿子的答案,记录答案,再将轻儿子的贡献删去

看代码可能更好理解

点此查看代码
#include<bits/stdc++.h>
#include<bits/extc++.h>
// using namespace __gnu_pbds;
// using namespace __gnu_cxx;
using namespace std;
#define infile(x) freopen(x,"r",stdin)
#define outfile(x) freopen(x,"w",stdout)
#define errfile(x) freopen(x,"w",stderr)
#ifdef LOCAL
    FILE *InFile = infile("in.in"),*OutFile = outfile("out.out");
    // FILE *ErrFile=errfile("err.err");
#else
    FILE *Infile = stdin,*OutFile = stdout;
    //FILE *ErrFile = stderr;
#endif
using ll=long long;using ull=unsigned long long;
using db = double;using ldb = long double;
const int N = 1e5 + 10;
int n,a[N];
namespace EDGE{
    struct EDGE{int to,next;}edge[N<<1];
    int head[N],cnt;
    inline void add(int u,int v){
        edge[++cnt] = {v,head[u]};
        head[u] = cnt;
    }
}using EDGE::edge;using EDGE::head;using EDGE::add;
int fa[N],siz[N],son[N];
void dfs1(int x){
    siz[x] = 1;
    for(int i = head[x]; i;i = edge[i].next){
        int y = edge[i].to;
        if(y == fa[x]) continue;
        fa[y] = x;dfs1(y);
        siz[x] += siz[y];
        if(siz[son[x]] < siz[y]) son[x] = y;
    }
}
int hugeson,cnt[N],ans[N],res;
inline void add(int x){
    cnt[a[x]]++;
    if(cnt[a[x]] == 1) res++;
}
inline void del(int x){
    cnt[a[x]]--;
    if(!cnt[a[x]]) res--;
}
inline void change(int x,int val){
    if(val>0) add(x);else del(x);
    for(int i = head[x]; i;i = edge[i].next){
        int y = edge[i].to;
        if(y == fa[x] || y == hugeson) continue;
        change(y,val);
    }
}
void dfs(int x,bool keep){
    for(int i = head[x]; i;i = edge[i].next){
        int y = edge[i].to;
        if(y == fa[x] || y == son[x]) continue;
        dfs(y,0);
    }
    if(son[x]) dfs(son[x],1),hugeson = son[x];
    change(x,1);ans[x] = res;hugeson = 0;
    if(!keep) change(x,-1);
}
inline void solve(){
    cin>>n;
    for(int i = 1,u,v;i < n; ++i){
        cin>>u>>v;
        add(u,v),add(v,u);
    }
    for(int i = 1;i <= n; ++i) cin>>a[i];
    dfs1(1);dfs(1,0);
    int m;cin>>m;
    for(int i = 1,q;i <= m; ++i) cin>>q,cout<<ans[q]<<'\n';
}
signed main(){
    cin.tie(nullptr)->sync_with_stdio(false);
    cout.tie(nullptr)->sync_with_stdio(false);
    solve();
}

例题

  1. Blood Cousins

    不懂为什么要打DSU On Tree,明明有更简单的做法,所以我没有打。

    但可以挂一下其他做法。

    求一下k级祖先,将题意转化成该祖先有几个k级子孙,记得答案减一

    点此查看代码
    #include<bits/stdc++.h>
    #include<bits/extc++.h>
    // using namespace __gnu_pbds;
    // using namespace __gnu_cxx;
    using namespace std;
    #define infile(x) freopen(x,"r",stdin)
    #define outfile(x) freopen(x,"w",stdout)
    #define errfile(x) freopen(x,"w",stderr)
    #ifdef LOCAL
        FILE *InFile = infile("in.in"),*OutFile = outfile("out.out");
        // FILE *ErrFile=errfile("err.err");
    #else
        FILE *Infile = stdin,*OutFile = stdout;
        //FILE *ErrFile = stderr;
    #endif
    using ll=long long;using ull=unsigned long long;
    using db = double;using ldb = long double;
    const int N = 1e5 + 10;
    int n,root,m;
    struct EDGE{int to,next;}edge[N<<1];
    int head[N],cnt;
    inline void add(int u,int v){
        edge[++cnt] = {v,head[u]};
        head[u] = cnt;
    }
    int siz[N],dep[N],fa[N],son[N],top[N],dfn[N],rdfn[N],tot,num[N];
    void dfs1(int x){
        siz[x] = 1;
        for(int i = head[x]; i;i = edge[i].next){
            int y = edge[i].to;
            if(y == fa[x]) continue;
            fa[y] = x;dfs1(y);
            siz[x] += siz[y];
            if(siz[son[x]] < siz[y]) son[x] = y;
        }
    }
    void dfs2(int x,int t){
        top[x] = t;
        dfn[x] = ++tot;
        rdfn[tot] = x;
        if(son[x]) dfs2(son[x],t);else return;
        for(int i = head[x]; i;i = edge[i].next){
            int y = edge[i].to;
            if(y == fa[x] || y == son[x]) continue;
            dfs2(y,y);
        }
    }
    inline int Get(int x,int k){
        int fx = top[x];
        while(k >= dfn[x] - dfn[fx] + 1){
            k -= dfn[x] - dfn[fx] + 1;
            x = fa[fx];
            if(!x) return 0;
            fx = top[x];
        }
        return rdfn[dfn[x]-k];
    }
    struct node{int k,id;};
    vector<node> q[N];
    int ans[N];
    void dfs(int x){
        dep[x] = dep[fa[x]] + 1,num[dep[x]]++;
        for(auto i : q[x]) ans[i.id] -= num[dep[x]+i.k];
        for(int i = head[x]; i;i = edge[i].next){
            int y = edge[i].to;
            if(y == fa[x]) continue;
            dfs(y);
        }
        for(auto i : q[x]) ans[i.id] += num[dep[x]+i.k]-1;
    }
    inline void solve(){
        cin>>n;
        for(int i = 1,fa;i <= n; ++i){
            cin>>fa;
            if(!fa) root = fa;
            else add(fa,i);
        }
        for(int i = 1;i <= n; ++i)
            if(!fa[i])
                dfs1(i),dfs2(i,i);
        cin>>m;
        for(int i = 1;i <= m; ++i){
            int v,p,fa;
            cin>>v>>p;fa = Get(v,p);
            q[fa].push_back({p,i});
        }
        for(int i = 1;i <= n; ++i)
            if(!fa[i]) dfs(i);
        for(int i = 1;i <= m; ++i) cout<<ans[i]<<' ';
    }
    signed main(){
        cin.tie(nullptr)->sync_with_stdio(false);
        cout.tie(nullptr)->sync_with_stdio(false);
        solve();    
    }
    
  2. Lomsat gelral

    这个才是正经的板子题。

    挂个代码

    点此查看代码
    #include<bits/stdc++.h>
    #include<bits/extc++.h>
    // using namespace __gnu_pbds;
    // using namespace __gnu_cxx;
    using namespace std;
    #define infile(x) freopen(x,"r",stdin)
    #define outfile(x) freopen(x,"w",stdout)
    #define errfile(x) freopen(x,"w",stderr)
    #ifdef LOCAL
        FILE *InFile = infile("in.in"),*OutFile = outfile("out.out");
        // FILE *ErrFile=errfile("err.err");
    #else
        FILE *Infile = stdin,*OutFile = stdout;
        //FILE *ErrFile = stderr;
    #endif
    using ll=long long;using ull=unsigned long long;
    using db = double;using ldb = long double;
    const int N = 1e5 + 10;
    namespace EDGE{
        struct EDGE{int to,next;}edge[N<<1];
        int head[N],cnt;
        inline void add(int u,int v){
            edge[++cnt] = {v,head[u]};
            head[u] = cnt;
        }
    }
    using EDGE::edge;using EDGE::add;using EDGE::head;
    int n,m,a[N];
    int fa[N],dfn[N],rdfn[N],dep[N],siz[N],son[N];
    void dfs1(int x){
        siz[x] = 1;
        for(int i = head[x]; i;i = edge[i].next){
            int y = edge[i].to;
            if(y == fa[x]) continue;
            fa[y] = x;dfs1(y);
            siz[x] += siz[y];
            if(siz[son[x]] < siz[y]) son[x] = y;
        }
    }
    int S,mx,cnt[N];
    ll ans[N],res;
    inline void add(int x){
        cnt[a[x]]++;
        if(cnt[a[x]] > mx) res = a[x],mx = cnt[a[x]];
        else if(cnt[a[x]] == mx) res += a[x];
    }
    inline void del(int x){cnt[a[x]]--;}
    void change(int x,int val){
        if(val > 0) add(x);else del(x);
        for(int i = head[x]; i;i = edge[i].next){
            int y = edge[i].to;
            if(y == fa[x] || y == S) continue;
            change(y,val);
        }
    }
    void dfs(int x,bool keep){
        //cerr<<x<<' '<<keep<<'\n';
        for(int i = head[x]; i;i = edge[i].next){
            int y = edge[i].to;
            if(y == fa[x] || y == son[x]) continue;
            dfs(y,0);
        }
        if(son[x]) dfs(son[x],1),S = son[x];
        change(x,1);ans[x] = res;S = 0;
        if(!keep) change(x,-1),res = mx = 0;
    }
    inline void solve(){
        cin>>n;
        for(int i = 1;i <= n; ++i) cin>>a[i];
        for(int i = 1,u,v;i < n; ++i) cin>>u>>v,add(u,v),add(v,u);
        dfs1(1);
        dfs(1,0);
        for(int i = 1;i <= n; ++i) cout<<ans[i]<<' ';
    }
    signed main(){
        cin.tie(nullptr)->sync_with_stdio(false);
        cout.tie(nullptr)->sync_with_stdio(false);
        solve();
    }
    
  3. Tree Requests

    如果一堆字符可以排列成回文串,当且仅当至多有一个字符的个数为奇数。

    然后就没了

    点此查看代码
    #include<bits/stdc++.h>
    #include<bits/extc++.h>
    // using namespace __gnu_pbds;
    // using namespace __gnu_cxx;
    using namespace std;
    #define infile(x) freopen(x,"r",stdin)
    #define outfile(x) freopen(x,"w",stdout)
    #define errfile(x) freopen(x,"w",stderr)
    #ifdef LOCAL
        FILE *InFile = infile("in.in"),*OutFile = outfile("out.out");
        // FILE *ErrFile=errfile("err.err");
    #else
        FILE *Infile = stdin,*OutFile = stdout;
        //FILE *ErrFile = stderr;
    #endif
    using ll=long long;using ull=unsigned long long;
    using db = double;using ldb = long double;
    const int N = 5e5 + 10;
    namespace EDGE{
        struct EDGE{int to,next;}edge[N<<1];
        int head[N],cnt;
        inline void add(int u,int v){
            edge[++cnt] = {v,head[u]};
            head[u] = cnt;
        }
    }
    using EDGE::edge;using EDGE::head;using EDGE::add;
    int n,m;
    char s[N];
    int siz[N],dep[N],son[N],fa[N];
    bool ans[N];
    vector<pair<int,int> > q[N];
    void dfs1(int x){
        siz[x] = 1;
        dep[x] = dep[fa[x]] + 1;
        for(int i = head[x]; i;i = edge[i].next){
            int y = edge[i].to;
            if(y == fa[x]) continue;
            fa[y] = x;dfs1(y);
            siz[x] += siz[y];
            if(siz[son[x]] < siz[y]) son[x] = y;
        }
    }
    int num[N][30],S,res;
    bool flag[N];
    void change(int x,int val){
        int k = s[x] - 'a';
        num[dep[x]][k] += val;
        for(int i = head[x]; i;i = edge[i].next){
            int y = edge[i].to;
            if(y == fa[x] || y == S) continue;
            change(y,val);
        }
    }
    inline bool check(int dep){
        int res = 0;
        for(int i = 0;i < 26; ++i) res += num[dep][i]&1;
        return res <= 1;
    }
    void dfs(int x,bool keep){
        for(int i = head[x]; i;i = edge[i].next){
            int y = edge[i].to;
            if(y == fa[x] || y == son[x]) continue;
            dfs(y,0);
        }
        if(son[x]) dfs(son[x],1),S = son[x];
        change(x,1);
        for(auto i : q[x]) ans[i.second] = check(i.first);
        S = 0;
        if(!keep) change(x,-1),res = 0;
    }
    inline void solve(){
        cin>>n>>m;
        for(int i = 2,fa;i <= n; ++i) cin>>fa,add(fa,i);
        cin>>(s+1);
        for(int i = 1,a,b;i <= m; ++i) cin>>a>>b,q[a].push_back(make_pair(b,i));
        dfs1(1);dfs(1,0);
        for(int i = 1;i <= m; ++i){
            cout<<(ans[i]?"Yes\n":"No\n");
        }
    }
    signed main(){
        cin.tie(nullptr)->sync_with_stdio(false);
        cout.tie(nullptr)->sync_with_stdio(false);
        solve();    
    }
    
  4. Arpa’s letter-marked tree and Mehrdad’s Dokhtar-k

    模拟赛的题,赛时连暴力都没打上,乐。

    听说这是DSU On Tree创始人特意出的题,已经是DSU On Tree难题了。

    先考虑\(O(n^2)\)的暴力,其实思路同上题,就是考虑回文串的性质,将其转化成异或。

    发现合法的状态只有这些种(23种)

    \[\begin{matrix} 000\dots00\\ 000\dots01\\ 000\dots10\\ \vdots\\ 100\dots00 \end{matrix}\]

    我们枚举这些合法状态即可

    每个节点记录从根节点到该点的状态,将边权转成点权。

    开一个桶,记录一下当前已经有的状态的最深的地方,注意要先统计答案再将该子树所有节点的状态加入,不然会出现错误。我就是在这里错的

    考虑到一个点与另一个点之间若能产生贡献,则贡献为\(dep_u+dep_v-2*dep_{lca}\)。然后记录最大值即可。

    说的可能不太明白,看代码应该就行了。

    (可以学习一下这篇代码dsu的写法,省去了change递归的常数,跑得更快,而且也更好调)

    点此查看代码
    #include<bits/stdc++.h>
    #include<bits/extc++.h>
    // using namespace __gnu_pbds;
    // using namespace __gnu_cxx;
    using namespace std;
    #define infile(x) freopen(x,"r",stdin)
    #define outfile(x) freopen(x,"w",stdout)
    #define errfile(x) freopen(x,"w",stderr)
    #ifdef LOCAL
        FILE *InFile = infile("in.in"),*OutFile = outfile("out.out");
        // FILE *ErrFile=errfile("err.err");
    #else
        FILE *Infile = stdin,*OutFile = stdout;
        //FILE *ErrFile = stderr;
    #endif
    using ll=long long;using ull=unsigned long long;
    using db = double;using ldb = long double;
    const int N = 5e5 + 10;
    namespace EDGE{
        struct EDGE{int to,next;}edge[N<<1];
        int head[N],cnt;
        inline void add(int u,int v){
            edge[++cnt] = {v,head[u]};
            head[u] = cnt;
        }
    }
    using EDGE::edge;using EDGE::head;using EDGE::add;
    int n,fa[N],siz[N],son[N],dep[N],pd[N];
    int L[N],R[N],tot,dfn[N],rdfn[N];
    void dfs1(int x){
        siz[x] = 1;
        L[x] = ++tot;
        rdfn[dfn[x] = tot] = x;
        dep[x] = dep[fa[x]] + 1;
        pd[x] ^= pd[fa[x]];
        for(int i = head[x]; i;i = edge[i].next){
            int y = edge[i].to;
            if(y == fa[x]) continue;
            dfs1(y);
            siz[x] += siz[y];
            if(siz[son[x]] < siz[y]) son[x] = y;
        }
        R[x] = tot;
    }
    int hugeson,ans[N],have[1<<23];
    vector<int> state;
    void dfs(int x,bool keep){
        for(int i = head[x]; i;i = edge[i].next){
            int y = edge[i].to;
            if(y == fa[x] || y == son[x]) continue;
            dfs(y,0);ans[x] = max(ans[x],ans[y]);
        }
        if(son[x]) dfs(son[x],1),ans[x] = max(ans[x],ans[son[x]]);
        if(have[pd[x]]) ans[x] = max(ans[x],have[pd[x]] - dep[x]);
        for(auto ok : state)
            if(have[pd[x]^ok]) ans[x] = max(ans[x],have[pd[x]^ok]-dep[x]);
        have[pd[x]] = max(dep[x],have[pd[x]]);
        for(int i = head[x]; i;i = edge[i].next){
            int y = edge[i].to;
            if(y == son[x]) continue;
            for(int j = L[y];j <= R[y]; ++j){
                int k = rdfn[j];
                if(have[pd[k]]) ans[x] = max(ans[x],have[pd[k]]+dep[k] - 2*dep[x]);
                for(auto ok : state)
                    if(have[pd[k]^ok])
                        ans[x] = max(ans[x],have[pd[k]^ok] + dep[k] - 2*dep[x]);
            }
            for(int j = L[y];j <= R[y]; ++j) 
                have[pd[rdfn[j]]] = max(have[pd[rdfn[j]]],dep[rdfn[j]]);
        }
        if(!keep) for(int i = L[x];i <= R[x]; ++i) have[pd[rdfn[i]]] = 0;
    }
    inline void solve(){
        state.emplace_back(0);
        for(int i = 0;i < 22; ++i) state.push_back(1<<i);
        cin>>n;
        for(int i = 2;i <= n; ++i){
            char x;
            cin>>fa[i]>>x;
            pd[i] = 1<<(x-'a');
            add(fa[i],i);
        }
        dfs1(1);dfs(1,0);
        for(int i = 1;i <= n; ++i) cout<<ans[i]<<' ';
    }
    signed main(){
        cin.tie(nullptr)->sync_with_stdio(false);
        cout.tie(nullptr)->sync_with_stdio(false);
        solve();    
    }
    
posted @ 2024-08-09 19:08  CuFeO4  阅读(11)  评论(0编辑  收藏  举报