树链剖分总结

树链剖分

树链剖分是搞啥的?

一种类似于原子崩坏的超能力。

先把树炸成好多好多条链。然后把这些链组装起来,以进行路径修改,路径查询。

几个概念

  • 重儿子,轻儿子
  • 轻边
  • 重边,重链

请尽情发挥艺术天赋,动手画棵树。

两个看起来没什么用的事实

  • vu的一个轻儿子: size[v]*2 < size[u]

  • 在DFS的时候,如果我们先采访重儿子,那么重链的DFS序连续。

一个看起来还是没什么用的事实

x到根的路径中,重链,轻边的条数都是log(n)级别的

假的证明:

我们从x向根节点走。

根据上面那条性质: 树的总结点个数 <= pow(2, x到根,轻边的个数)

所以x到根路径上,轻边的条树,是log级别的。

而,两条重链之间,至少有一条轻边。所以重链的条数,也是log级别的。

似乎还是有那么一点点用的

推广一下上面的结论。

( ⊙ o ⊙ )!任意两点之间的路径,重链,轻边的条数都是log(n)级别的哎!

只要我们要维护的信息满足区间可合并【区间最值,区间和,区间GCD】

我们就可以把一条链,分成log条。

又每一条链DFS序连续。因此我们可以用线段树维护每一条链。


容易搞错的地方

  • dfn[x]x区分开、
  • 在找两个点的LCA的时候,top深度高的先往上爬。

几个栗子

POJ3237

题意:区间更新,区间取反,区间最值

我发明了一种神妙的战法。在push_up的时候把懒惰标记也传上去了。

调代码的时候,

WA! 懒惰标记上天啦!

我的内心是崩溃的.....

所以,我是来搞笑的吗?

code

#include <iostream>
#include <vector>
#include <cstdio>
using namespace std;
typedef pair<int,int> pii;
const int N = 200000+10;
const int INF = 1e9+7;

vector<pii> g[N];
int edge_id[N];
int size[N];
int dep[N];
int son[N], par[N], dis[N];
int dfn[N], moment, who[N];
int top[N];

int T, n;
int u[N], v[N], w[N];

void init() {
    moment = 0;
    for(int i=0;i<N;i++)
        g[i].clear();
}

void presolve(int u,int p) {
    size[u] = 1;
    par[u] = p;
    int maxSize=0, id=-1;
    for(int i=0;i<g[u].size();i++) {
        int v=g[u][i].first;
        if(v==p) continue;
        dis[v] = g[u][i].second;
        dep[v] = dep[u]+1;
        
        presolve(v, u);
        size[u] += size[v];
        if(size[v] > maxSize)
            maxSize = size[v], id = v;
    }
    son[u] = id;
}

void dfs(int u,int p) {
    dfn[u] = ++moment;
    who[moment] = u;
    top[u] = p;
    //printf("u = %d, p = %d \n", u, p);
    if (son[u] != -1) {
        //printf("%d -> %d\n",u, son[u]);
        dfs(son[u], p);
    }
    for(int i=0;i<g[u].size();i++) {
        int v=g[u][i].first;
        if(v==son[u] || v==par[u]) continue;
        dfs(v, v);
    }
}

struct Data {
    int l, r;
    int mx, mn;
    int ne;
    Data operator + (const Data & o) const {
        Data ans = o; ans.ne = 0;
        ans.mx = max(o.mx, mx);
        ans.mn = min(o.mn, mn);
        return ans;
    }
} nod[N<<2]; 

void build(int l,int r,int rt) {
    nod[rt].l = l, nod[rt].r = r; nod[rt].ne = 0;

    if (l==r) {
        nod[rt].mn = nod[rt].mx = dis[who[l]];
        return;
    }
    int mid=(l+r)>>1;
    build(l,mid,rt<<1);
    build(mid+1,r,rt<<1|1);
    nod[rt] = nod[rt<<1] + nod[rt<<1|1];
}

void push_down(int rt) {
    if (nod[rt].ne) {
        //printf("##### rt = %d\n", rt);
        nod[rt<<1].mx = -nod[rt<<1].mx;
        nod[rt<<1].mn = -nod[rt<<1].mn;
        swap(nod[rt<<1].mx, nod[rt<<1].mn);

        nod[rt<<1|1].mx = -nod[rt<<1|1].mx;
        nod[rt<<1|1].mn = -nod[rt<<1|1].mn;
        swap(nod[rt<<1|1].mx, nod[rt<<1|1].mn);

        nod[rt<<1].ne ^= 1;
        nod[rt<<1|1].ne ^= 1;

        nod[rt].ne = 0;
    }
}

int query(int l,int r,int rt,int L,int R) {
    //printf("gg\n");
    //printf("! %d %d %d %d\n", l,r,L,R);
    if (L<=l&&r<=R) {
        //printf("rt=%d, [%d, %d] mx=%d\n", rt,l,r,nod[rt].mx);
        return nod[rt].mx;
    }
    //if (rt!=8)
    push_down(rt);
    int mid = (l+r)>>1;
    int ans = - INF;
    if (L <= mid) ans = query(l,mid,rt<<1,L,R);
    if (R  > mid) ans = max(ans, query(mid+1,r,rt<<1|1,L,R));
    //printf("%d %d %d %d %d\n", l,r,L,R,ans);
    return ans; 
}

void update(int l,int r,int rt,int pos,int x) {
    if (l == r) {
        nod[rt].mx = nod[rt].mn = x;
        nod[rt].ne = 0;
        return;
    }
    //if (rt!=8)
    push_down(rt);
    int mid = (l+r)>>1;
    if (pos <= mid)
        update(l,mid,rt<<1,pos,x);
    else
        update(mid+1,r,rt<<1|1,pos,x);

    nod[rt] = nod[rt<<1] + nod[rt<<1|1];
}

void neg_update(int l,int r,int rt,int L,int R) {
    //printf("%d %d %d %d\n", l,r, L,R);
    if (L<=l&&r<=R) {
        nod[rt].ne ^= 1;
        nod[rt].mn *= -1;
        nod[rt].mx *= -1;

        //printf("rt=%d [%d, %d] %d %d, ne=%d\n", rt,nod[rt].l, nod[rt].r, nod[rt].mn, nod[rt].mx, nod[rt].ne);
        swap(nod[rt].mn, nod[rt].mx);
        return;
    }
    //if(rt!=8)
    push_down(rt);
    int mid = (l+r)>>1;
    if (L <= mid)
        neg_update(l,mid,rt<<1,L,R);
    if (R  > mid)
        neg_update(mid+1,r,rt<<1|1,L,R);

    nod[rt] = nod[rt<<1] + nod[rt<<1|1];
}

int max_on_path(int u, int v) {
    int ans = - INF;
    int f1 = top[u], f2 = top[v];
    while (f1 != f2) {
        //printf("%d %d\n", f1, f2);
        if (dep[f1] < dep[f2]) {
            swap(f1, f2);
            swap(u, v);
        }
        //printf("[%d, %d]\n", dfn[f1], dfn[u]);
        ans = max(ans, query(1,n,1,dfn[f1],dfn[u]));
        //printf("%d\n", ans);
        u = par[f1]; f1 = top[u];
        //printf("%d %d %d\n", u, f1, f2);
    }

    if (u == v) return ans;
    if (dep[u] > dep[v]) swap(u, v);
    ans = max(ans, query(1,n,1,dfn[son[u]],dfn[v]));

    return ans;
}

void update_on_tree(int u, int v) {
    int f1 = top[u], f2 = top[v];
    while (f1 != f2) {
        if (dep[f1] < dep[f2]) {
            swap(f1, f2);
            swap(u, v);
        }
        neg_update(1,n,1,dfn[f1],dfn[u]);
        u = par[f1]; f1 = top[u];
    }
    if (u == v) return;
    if (dep[u] > dep[v]) swap(u, v);
    neg_update(1,n,1,dfn[son[u]], dfn[v]);
}

int main() {
    scanf("%d", &T);
    while (T --) {
        scanf("%d", &n);
        init();

        for (int i=1;i<n;i++) {
            scanf("%d %d %d", &u[i], &v[i], &w[i]);
            g[u[i]].push_back(make_pair(v[i], w[i]));
            g[v[i]].push_back(make_pair(u[i], w[i]));
        }
        
        presolve(1, 1);
        dfs(1, 1);

        for (int i=1;i<n;i++) {
            if (par[u[i]] == v[i])
                edge_id[i] = dfn[u[i]];
            else
                edge_id[i] = dfn[v[i]];
        }

        /*
        for(int i=1;i<=n;i++) {
            printf("i = %d\n", i);
            printf("dfn = %d, son = %d, par = %d, dis = %d, top = %d\n", dfn[i],son[i],par[i],dis[i],top[i]);
        }
        */
        
        build(1, n, 1);

        char op[5]; int a, b;
        while (scanf("%s", op)) {
            if (op[0] == 'D') break;
            scanf("%d %d", &a, &b);
            if (op[0] == 'C') {
                update(1,n,1,edge_id[a],b);
            }
            if (op[0] == 'Q') {
                int ans = max_on_path(a, b);
                if (a == b) ans = 0;
                printf("%d\n", ans);
            }   
            if (op[0] == 'N') {
                update_on_tree(a, b);
            }
        }
    }
}



Gym 101741 C

题意

给出一棵树,与m条路径,选出最小的点集,使得每条路径都包含点集中的点。

做法

  • 在一条链上,就是个按右端点排序的贪心。

  • 把路径按照LCA的深度从高到低排序。

  • 然后遍历所有路径,如果路径上没有选择的点,那就选择LCA。否则,什么都不做。这个可以通过单点更新,区间查询来实现。

code

#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
const int N = 400000+10;
int n;
int son[N],size[N],par[N],top[N],dep[N];
int dfn[N],who[N],moment;
vector<int> g[N];
void preSovle(int u,int p) {
    size[u]=1;
    par[u] =p;
    dep[u] =dep[p]+1;
    int mx=0, bst=-1;
    for(int i=0;i<g[u].size();i++) {
        int v=g[u][i];
        if (v==p) continue;
        preSovle(v, u);
        size[u] += size[v];
        if (size[v] > mx) {
            mx = size[v];
            bst = v;
        }
    }
    son[u] = bst;
}
void dfs(int u,int p) {
    top[u]=p;
    dfn[u]=++moment;
    who[moment]=u;
    if(son[u]!=-1){
        dfs(son[u],p);
    }
    for(int i=0;i<g[u].size();i++) {
        int v=g[u][i];
        if(v!=par[u]&&v!=son[u])
            dfs(v,v);
    }
}

struct Query {
    int u,v;
    int lca;
    bool operator < (const Query & o) const {
        return dep[lca] > dep[o.lca];
    }
} q[N];
int sum[N<<2];

void update(int l,int r,int rt,int pos){
    if(l==r) {
        sum[rt] ++;
        return;
    }
    int mid=(l+r)>>1;
    if (pos<=mid) update(l,mid,rt<<1,pos);
    else update(mid+1,r,rt<<1|1,pos);
    sum[rt]=sum[rt<<1]+sum[rt<<1|1];
}
int query(int l,int r,int rt,int L,int R){
    if(L<=l&&r<=R) {
        return sum[rt];
    }
    int mid=(l+r)>>1;
    int ans=0;
    if(L<=mid) ans+=query(l,mid,rt<<1,L,R);
    if(R >mid) ans+=query(mid+1,r,rt<<1|1,L,R);
    return ans;
}

pair<int,int> sum_path(int u,int v) { // (lca, sum)
    int f1=top[u], f2=top[v];
    int ans=0;
    while(f1!=f2) {
        //printf("%d %d\n", f1, f2);
        if (dep[f1] < dep[f2]) {
            swap(f1,f2);
            swap(u,v);
        }
        ans+=query(1,n,1,dfn[f1],dfn[u]);
        u=par[f1], f1=top[u];
        //printf("%d %d\n", f1, f2);
    }

    if(dep[u]>dep[v]) swap(u,v);
    ans+=query(1,n,1,dfn[u],dfn[v]);
    return make_pair(u, ans);
}

int main() {
    scanf("%d",&n);
    for(int i=1;i<n;i++) {
        int u,v;
        scanf("%d%d",&u,&v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    preSovle(1,1);

    dfs(1,1);

    int m; scanf("%d", &m);
    for(int i=1;i<=m;i++) {
        scanf("%d %d", &q[i].u, &q[i].v);
        q[i].lca = sum_path(q[i].u,q[i].v).first;
        //printf("ok\n");
    }
    sort(q+1,q+1+m);
    vector<int> ret;
    for(int i=1;i<=m;i++) {
        //printf("%d\n", q[i].lca);
        if(sum_path(q[i].u, q[i].v).second == 0) {
            ret.push_back(q[i].lca);
            update(1,n,1,dfn[q[i].lca]);
        }   
    }

    printf("%d\n", ret.size());
    for(auto x: ret) {
        printf("%d ", x);
    }
}


BZOJ2243

题意

一棵树,路径赋值,路径颜色段数查询。

1121:有三段

1121233:有五段

题解

线段树维护:

  • 区间左端的颜色
  • 区间右端的颜色
  • 区间内总共有多少段颜色

信息可以合并哎!

code

写残了...

#include <iostream>
#include <cstdio>
#include <vector>
using namespace std;
const int N = 400000+10;
vector<int> g[N];
int a[N],son[N],par[N],dep[N],sz[N],top[N],dfn[N],who[N],moment;
int n, m;
void init() {
    for(int i=0;i<N;i++) 
        g[i].clear();
    moment=0;
}
void pre(int u,int p){
    par[u]=p,dep[u]=dep[p]+1,sz[u]=1;
    int hson=-1,mx=0;
    for(int i=0;i<g[u].size();i++) {
        int v=g[u][i]; if(v==p) continue;
        pre(v,u);
        if(sz[v]>mx) mx=sz[v], hson=v;
        sz[u]+=sz[v];
    }
    son[u]=hson;
}
void dfs(int u,int p){
    top[u]=p;
    dfn[u]=++moment, who[moment]=u;
    if(son[u]!=-1) dfs(son[u],p);
    for(int i=0;i<g[u].size();i++) {
        int v=g[u][i]; if(v==son[u]||v==par[u]) continue;
        dfs(v,v);
    }
}
struct Data {
    int lef;
    int rig;
    int tot;
    int lazy;
    Data operator + (const Data & o) const {
        Data ret;
        ret.lef = lef;  ret.rig = o.rig;    
        ret.tot = (rig == o.lef) ? (tot+o.tot-1) : (tot+o.tot);
        ret.lazy = 0;
        return ret;
    }
} nod[N<<2];
Data rev(Data nod) {
    swap(nod.lef, nod.rig);
    return nod;
}
Data set_value(int rt, int x) {
    Data ans;
    ans.lef = ans.rig = x;
    ans.tot = 1; ans.lazy = 0;
    return ans;
}
void push_down(int rt) {
    if (nod[rt].lazy) {
        nod[rt<<1] = set_value(rt<<1, nod[rt].lazy);  
        nod[rt<<1].lazy = nod[rt].lazy;
        nod[rt<<1|1]=set_value(rt<<1|1,nod[rt].lazy);
        nod[rt<<1|1].lazy = nod[rt].lazy;
        nod[rt].lazy  = 0;
    }
}
void build(int l,int r,int rt) {
    nod[rt].lazy = 0;
    if (l==r) {
        nod[rt].lef = nod[rt].rig = a[who[l]];
        nod[rt].tot = 1;
        return;
    }
    int mid = (l+r)>>1;
    build(l,mid,rt<<1);
    build(mid+1,r,rt<<1|1);
    nod[rt] = nod[rt<<1] + nod[rt<<1|1];
}
void update(int l,int r,int rt,int L,int R,int x) {
    if (L<=l&&r<=R) {
        nod[rt]=set_value(rt, x);
        nod[rt].lazy = x;
        return;
    }
    push_down(rt);
    int mid = (l+r)>>1;
    if (L<=mid) update(l,mid,rt<<1,L,R,x);
    if (R >mid) update(mid+1,r,rt<<1|1,L,R,x);
    nod[rt] = nod[rt<<1] + nod[rt<<1|1];
}
Data query(int l,int r,int rt,int L,int R) {
    if (L<=l&&r<=R) {
        return nod[rt];
    }
    push_down(rt);
    int mid=(l+r)>>1;
    if (L<=mid && R<=mid) return query(l,mid,rt<<1,L,R);
    if (L>mid && R>mid) return query(mid+1,r,rt<<1|1,L,R);
    return query(l,mid,rt<<1,L,R) + query(mid+1,r,rt<<1|1,L,R);
}
int nex[N];
int get_path(int u, int v) {

    int cu = u, cv = v;
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u,v), swap(cu, cv);
        Data tmp = query(1,n,1,dfn[top[u]], dfn[u]);
        nex[u] = top[u]; 
        u = par[top[u]];
    }
    if (dep[u] > dep[v]) swap(u, v), swap(cu, cv);
    nex[v] = u; 
    
    Data ans, ans_; 
    bool find = 0, find_ = 0;
    while(cu != u) {
        if (find == 1)
            ans = query(1,n,1,dfn[nex[cu]], dfn[cu]) + ans;
        else
            find = 1, ans = query(1,n,1,dfn[nex[cu]],dfn[cu]);
        cu = par[top[cu]];
    }
    ans = rev(ans);
    if (find == 1)
        ans = ans + query(1,n,1,dfn[u],dfn[v]);
    else
        find = 1, ans = query(1,n,1,dfn[u],dfn[v]); 
    
    while(cv != v) {
        if (find_ == 1)
            ans_ = query(1,n,1,dfn[nex[cv]], dfn[cv]) + ans_;
        else
            find_ = 1, ans_ = query(1,n,1,dfn[nex[cv]],dfn[cv]);
        cv = par[top[cv]];
    }
    
    if (find && find_)
        ans = ans + ans_;
    else if (find_)
        ans = ans_;
    else if (find)
        ans = ans;

    return ans.tot;
}
void modidy(int u,int v,int x) {
    while(top[u]!=top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u,v);
        update(1,n,1,dfn[top[u]],dfn[u],x);
        u=par[top[u]];
    }
    if(dep[u]>dep[v]) swap(u,v);
    update(1,n,1,dfn[u],dfn[v],x);
}
int main() {
    while (~ scanf("%d %d", &n, &m)) {
        init();
        for (int i=1;i<=n;i++)
            scanf("%d", &a[i]);
        for (int i=1;i<n;i++) {
            int u, v; 
            scanf("%d%d",&u,&v);
            g[u].push_back(v);
            g[v].push_back(u);
        }
        pre(1,1);
        dfs(1,1);
        build(1,n,1);

        for(int i=1;i<=m;i++) {
            char op[2]; int u, v, x;
            scanf("%s%d%d",op,&u,&v);
            if (op[0]=='Q') {
                printf("%d\n", get_path(u,v));
            } else {
                scanf("%d", &x); modidy(u,v,x);
            }
        }
    }
}
posted @ 2018-05-28 11:17  RUSH_D_CAT  阅读(198)  评论(0编辑  收藏  举报