模板 - 图论 - 树链剖分

今天来啃一下这个树剖吧。

 

模板题是要求这四个问题:

  • 将树从x到y结点最短路径上所有节点的值都加上z
  • 求树从x到y结点最短路径上所有节点的值之和
  • 将以x为根节点的子树内所有节点值都加上z
  • 求以x为根节点的子树内所有节点值之和

http://www.cnblogs.com/zwfymqz/p/8094500.html

 


 

 

这里的边要开两倍,线段树当然要开四倍。

第一次模板:

#include<bits/stdc++.h>
using namespace std;
#define ll long long

/* 树链剖分 begin*/

#define MAXN 100100
#define ls k<<1
#define rs k<<1|1

struct Edge{
    int u,v,nxt;
}edge[MAXN*2];

int head[MAXN];
int num=1;

void addedge(int u,int v){
    edge[num].u=u;
    edge[num].v=v;
    edge[num].nxt=head[u];
    head[u]=num++;
}

int deep[MAXN];//节点的深度
int fa[MAXN];//节点的父亲
int son[MAXN];//节点的重儿子
int tot[MAXN];//节点子树的大小
int top[MAXN];
int idx[MAXN];

int cnt=0;
int a[MAXN];
int b[MAXN];
int MOD;

int dfs1(int now, int f, int dep) {
    deep[now] = dep;
    fa[now] = f;
    tot[now] = 1;
    int maxson = -1;
    for (int i = head[now]; i != -1; i = edge[i].nxt) {
        if (edge[i].v == f) continue;
        tot[now] += dfs1(edge[i].v, now, dep + 1);
        if (tot[edge[i].v] > maxson) maxson = tot[edge[i].v], son[now] = edge[i].v;
    }
    return tot[now];
}

void dfs2(int now, int topf) {
    idx[now] = ++cnt;
    a[cnt] = b[now];
    top[now] = topf;
    if (!son[now]) return ;
    dfs2(son[now], topf);
    for (int i = head[now]; i != -1; i = edge[i].nxt)
        if (!idx[edge[i].v])
            dfs2(edge[i].v, edge[i].v);
}

struct Tree {
    int l, r, w, siz, f;
} T[MAXN*4];

void update(int k) { //更新
    T[k].w = (T[ls].w + T[rs].w + MOD) % MOD;
}

void Build(int k, int l, int r) {
    T[k].l = l; T[k].r = r; T[k].siz = r - l + 1;
    if (l == r) {
        T[k].w = a[l];
        return ;
    }
    int mid = (l + r) >> 1;
    Build(ls, l, mid);
    Build(rs, mid + 1, r);
    update(k);
}

void pushdown(int k) { //下传标记
    if (!T[k].f) return ;
    T[ls].w = (T[ls].w + T[ls].siz * T[k].f) % MOD;
    T[rs].w = (T[rs].w + T[rs].siz * T[k].f) % MOD;
    T[ls].f = (T[ls].f + T[k].f) % MOD;
    T[rs].f = (T[rs].f + T[k].f) % MOD;
    T[k].f = 0;
}

void IntervalAdd(int k, int l, int r, int val) { //区间加
    if (l <= T[k].l && T[k].r <= r) {
        T[k].w += T[k].siz * val;
        T[k].f += val;
        return ;
    }
    pushdown(k);
    int mid = (T[k].l + T[k].r) >> 1;
    if (l <= mid)    IntervalAdd(ls, l, r, val);
    if (r > mid)    IntervalAdd(rs, l, r, val);
    update(k);
}

int IntervalSum(int k, int l, int r) { //区间求和
    int ans = 0;
    if (l <= T[k].l && T[k].r <= r)
        return T[k].w;
    pushdown(k);
    int mid = (T[k].l + T[k].r) >> 1;
    if (l <= mid) ans = (ans + IntervalSum(ls, l, r)) % MOD;
    if (r > mid)  ans = (ans + IntervalSum(rs, l, r)) % MOD;
    return ans;
}

int TreeSum(int x, int y) { //求x与y路径上的和
    int ans = 0;
    while (top[x] != top[y]) {
        if (deep[top[x]] < deep[top[y]]) swap(x, y);
        ans = (ans + IntervalSum(1, idx[ top[x] ], idx[x])) % MOD;
        x = fa[ top[x] ];
    }
    if (deep[x] > deep[y]) swap(x, y);
    ans = (ans + IntervalSum(1, idx[x], idx[y])) % MOD;
    return ans;
}

void TreeAdd(int x, int y, int val) { //对于x,y路径上的点加val的权值
    while (top[x] != top[y]) {
        if (deep[top[x]] < deep[top[y]]) swap(x, y);
        IntervalAdd(1, idx[ top[x] ], idx[x], val);
        x = fa[ top[x] ];
    }
    if (deep[x] > deep[y])    swap(x, y);
    IntervalAdd(1, idx[x], idx[y], val);
}

//求x所在子树的和
//IntervalSum(1,idx[x],idx[x]+tot[x]-1);
//对x所在子树的点加val的权值
//IntervalAdd(1,idx[x],idx[x]+tot[x]-1,val%MOD);

/* 树链剖分 end */

int main(){
    memset(head,-1,sizeof(head));

    int n,m,r,p;
    cin>>n>>m>>r>>p;
    MOD=p;

    for(int i=1;i<=n;i++){
        cin>>b[i];
    }
    for(int i=1;i<=n-1;i++){
        int u,v;
        cin>>u>>v;
        addedge(u,v);
        addedge(v,u);
    }

    dfs1(r,0,1);
    dfs2(r,r);

    Build(1,1,n);
    while(m--){
        int opt,x,y,z;
        cin>>opt;
        switch(opt){
        case 1:
            cin>>x>>y>>z;
            TreeAdd(x,y,z%p);
            break;
        case 2:
            cin>>x>>y;
            cout<<TreeSum(x,y)<<endl;
            break;
        case 3:
            cin>>x>>z;
            IntervalAdd(1,idx[x],idx[x]+tot[x]-1,z%p);
            break;
        case 4:
            cin>>x;
            cout<<IntervalSum(1,idx[x],idx[x]+tot[x]-1)<<endl;
        }

    }
}
View Code

 

posted @ 2019-04-04 00:12  韵意  阅读(135)  评论(0编辑  收藏  举报