树链剖分 模板

易错点

注意两次dfs里记录信息时要核对一下是不是记录正确

然后就是线段树里一堆 += 不要写成 =

Code

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;

// MAXN sizes every vertex-indexed array below (and, scaled, the edge and
// segment-tree pools).
const int MAXN = 1e5 + 5;
// MAXM is not referenced by any code visible in this file.
const int MAXM = 1e5 + 5;
// Large sentinel; dfs1 uses -INF as the initial "largest child subtree" value.
const int INF = 0x3f3f3f3f;

// n: number of vertices, m: number of operations (both read in main).
int n,m;
// root: vertex the two DFS passes start from; p: modulus applied to sums;
// tot: counter incremented by dfs2 to hand out DFS-order positions.
int root,p,tot;
// One directed edge in a linked adjacency list; edges live in the pool l[]
// and are addressed by index (index 0 is never used, so 0 ends a list).
struct Edge {
    int nxt;    // index in l[] of the next edge leaving the same vertex, 0 if none
    int to,w;   // to: destination vertex; w: not read or written by the visible code
} l[MAXN << 1]; // add() is called twice per input edge, hence twice MAXN slots
// One segment-tree node; node `pos` has children `pos << 1` and `pos << 1 | 1`.
struct Tree {
    int sum;    // aggregate of the values in the node's interval
    int tag;    // pending per-element increment not yet applied to the children
} t[MAXN << 2];
// Filled by dfs1 for every vertex: depth (root has depth 1), parent,
// subtree size, and the child with the largest subtree (0 when there is none).
int deep[MAXN],fa[MAXN],siz[MAXN],son[MAXN];
// head[v]: index in l[] of the first edge leaving v (0 = no edges);
// cnt: number of edges stored in l[] so far.
int head[MAXN],cnt;
// id[v]: DFS-order position given to v by dfs2; val[i]: initial weight stored at
// position i; w[v]: input weight of vertex v; top[v]: first vertex of the chain
// that dfs2 placed v on.
int id[MAXN],val[MAXN],w[MAXN],top[MAXN];

void add(int x,int y) {
    cnt++;
    l[cnt].nxt = head[x];
    l[cnt].to = y;
    head[x] = cnt;
    return;
}

// First DFS pass: for vertex x reached from `from`, record its parent, depth
// and subtree size, and remember in son[x] the child whose subtree is largest
// (the first such child wins ties). son[x] is left untouched for leaves.
void dfs1(int x,int from) {
    fa[x] = from;
    deep[x] = deep[from] + 1;
    siz[x] = 1;
    int heaviest = -INF;
    for(int e = head[x]; e != 0; e = l[e].nxt) {
        const int child = l[e].to;
        if(child != from) {
            dfs1(child,x);
            siz[x] += siz[child];
            if(heaviest < siz[child]) {
                heaviest = siz[child];
                son[x] = child;
            }
        }
    }
}

// Second DFS pass: give x the next DFS-order position, copy its weight to that
// position, and record y as the top of x's chain.
// The child stored in son[x] is visited first and stays on the same chain, so
// a chain occupies consecutive positions; every other child starts a new chain
// with itself as top.
void dfs2(int x,int y,int from) {
    top[x] = y;
    id[x] = ++tot;
    val[id[x]] = w[x];
    const int heavy = son[x];
    if(heavy == 0) return;          // no recorded child: nothing below to number
    dfs2(heavy,y,x);
    for(int e = head[x]; e != 0; e = l[e].nxt) {
        const int child = l[e].to;
        if(child != from && child != heavy) dfs2(child,child,x);
    }
}

// Recompute the sum of segment-tree node `pos` from its two children and store
// it reduced modulo p.
// The addition is done in 64-bit: the child sums are plain ints that are not
// guaranteed to be below p, so adding them as int could overflow.
void update(int pos) {
    const long long total = (long long)t[pos << 1].sum + t[pos << 1 | 1].sum;
    t[pos].sum = (int)(total % p);
    return;
}

void pushdown(int L,int R,int pos) {
    if(!t[pos].tag) return;
    int mid = (L + R) >> 1;
    t[pos << 1].sum += t[pos].tag * (mid - L + 1);
    t[pos << 1 | 1].sum += t[pos].tag * (R - mid);
    t[pos << 1].tag += t[pos].tag;
    t[pos << 1 | 1].tag += t[pos].tag;
    t[pos].tag = 0;
    return;
}

// Build the segment tree for positions [L, R] rooted at node `pos`.
// A leaf takes val[L] modulo p; an inner node is built from its two halves and
// then combined with update().
void build(int L,int R,int pos) {
    if(L != R) {
        const int mid = (L + R) >> 1;
        const int lc = pos << 1;
        const int rc = lc | 1;
        build(L,mid,lc);
        build(mid + 1,R,rc);
        update(pos);
    } else {
        t[pos].sum = val[L] % p;
    }
}

void modify(int L,int R,int ll,int rr,int pos,int v) {
    if(ll <= L && R <= rr) {
        t[pos].sum += v * (R - L + 1);
        t[pos].tag += v;
        return;
    }
    if(R < ll || rr < L) return;
    int mid = (L + R) >> 1;
    pushdown(L,R,pos);
    modify(L,mid,ll,rr,pos << 1,v);
    modify(mid + 1,R,ll,rr,pos << 1 | 1,v);
    update(pos);
    return;
}

// Return the sum of positions [ll, rr] modulo p. `pos` is the current node and
// [L, R] the interval it covers.
// The two partial results are added in 64-bit and reduced before returning, so
// the value handed back up the recursion always fits in int instead of growing
// with every level.
int query(int L,int R,int ll,int rr,int pos) {
    if(ll <= L && R <= rr) {
        return t[pos].sum % p;
    }
    if(R < ll || rr < L) return 0;
    int mid = (L + R) >> 1;
    pushdown(L,R,pos);
    const long long left = query(L,mid,ll,rr,pos << 1);
    const long long right = query(mid + 1,R,ll,rr,pos << 1 | 1);
    return (int)((left + right) % p);
}

// Add z (taken modulo p) to every vertex on the tree path between x and y.
// While the endpoints sit on different chains, the endpoint whose chain top is
// deeper has its chain segment [top, endpoint] updated and then jumps to the
// parent of that top. Once both are on one chain, the positions between them
// are updated in a single range.
void way_add(int x,int y,int z) {
    const int delta = z % p;
    for(; top[x] != top[y]; x = fa[top[x]]) {
        if(deep[top[y]] > deep[top[x]]) swap(x,y);
        modify(1,n,id[top[x]],id[x],1,delta);
    }
    if(deep[y] < deep[x]) swap(x,y);    // make x the shallower endpoint (smaller id)
    modify(1,n,id[x],id[y],1,delta);
}

// Return the sum, modulo p, of the values on the tree path between x and y.
// The path is walked chain by chain exactly as in way_add: the endpoint whose
// chain top is deeper contributes its segment [top, endpoint] and jumps to the
// parent of that top; the final segment lies on one chain.
// The running total is kept in 64-bit and reduced after every segment, so
// adding a segment result can never overflow int.
int way_ask(int x,int y) {
    long long ans = 0;
    while(top[x] != top[y]) {
        if(deep[top[x]] < deep[top[y]]) swap(x,y);
        ans = (ans + query(1,n,id[top[x]],id[x],1)) % p;
        x = fa[top[x]];
    }
    if(deep[x] > deep[y]) swap(x,y);    // make x the shallower endpoint (smaller id)
    ans = (ans + query(1,n,id[x],id[y],1)) % p;
    return (int)ans;
}

// Add y (taken modulo p) to every vertex in the subtree of x.
// dfs2 numbers a subtree consecutively, so it is the position range
// [id[x], id[x] + siz[x] - 1].
void son_add(int x,int y) {
    const int first = id[x];
    const int last = first + siz[x] - 1;
    modify(1,n,first,last,1,y % p);
}

// Return query()'s result for the subtree of x, i.e. for the position range
// [id[x], id[x] + siz[x] - 1] that dfs2 assigned to that subtree.
int son_ask(int x) {
    const int first = id[x];
    const int last = first + siz[x] - 1;
    return query(1,n,first,last,1);
}

int main() {
    scanf("%d%d%d%d",&n,&m,&root,&p);
    for(int i = 1; i <= n; i++) scanf("%d",&w[i]);
    int x,y,z;
    for(int i = 1; i < n; i++) {
        scanf("%d%d",&x,&y);
        add(x,y);
        add(y,x);
    }
    dfs1(root,0);
    dfs2(root,root,0);
    build(1,n,1);
    
//    for(int i = 1;i <= n;i++) {
//        cout<<"DEBUG:"<<top[i]<<" "<<fa[i]<<" "<<deep[i]<<" "<<id[i]<<" "<<val[i]<<endl;
//    }
    /*
    操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z
    操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和
    操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z
    操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和
    */
    int opt;
    while(m--) {
        scanf("%d",&opt);
        if(opt == 1) {
            scanf("%d%d%d",&x,&y,&z);
            way_add(x,y,z);
        } else if(opt == 2) {
            scanf("%d%d",&x,&y);
            printf("%d\n",way_ask(x,y) % p);
        } else if(opt == 3) {
            scanf("%d%d",&x,&y);
            son_add(x,y);
        } else if(opt == 4) {
            scanf("%d",&x);
            printf("%d\n",son_ask(x) % p);
        }
    }
    return 0;
}

 

posted @ 2018-08-16 11:01  Floatiy  阅读(229)  评论(2)  编辑  收藏  举报