// Luogu P3384 树链剖分模板题 (heavy-light decomposition template problem)

#include<bits/stdc++.h>
using namespace std;
// Inclusive loop helpers: rep counts i up from a to n, per counts i down from n to a.
// No trailing ';' in the expansion, otherwise the loop body would be an empty
// statement and the user's body would run once, outside the loop's scope.
#define rep(i, a, n) for(int i = (a); i <= (n); ++ i)
#define per(i, a, n) for(int i = (n); i >= (a); -- i)
using ll = long long;

// Array capacity used for vertex-indexed and dfn-indexed arrays.
const int N = 2e6 + 5;

// const ll mod = 1e9 + 7;
int mod;   // modulus; a runtime value rather than a constant

const double Pi = acos(-1.0);
const int INF = 0x3f3f3f3f;

// Presumably NTT constants for 998244353 (3 * 332748118 == 998244354);
// not used by the heavy-light decomposition code.
const int G = 3;
const int Gi = 332748118;
// Modular exponentiation: returns a^b modulo the global `mod` (binary exponentiation).
// The base is first reduced into [0, mod), so a * a stays below mod^2 and cannot
// overflow even when the caller passes a >= mod or a < 0.
// The accumulator starts at 1 % mod, so the result is correct when mod == 1.
// Negative exponents are not supported; they return 1 % mod instead of looping
// forever (b >>= 1 on a negative value never reaches 0).
ll qpow(ll a, ll b) {
    ll res = 1 % mod;
    a %= mod;
    if(a < 0) a += mod;
    while(b > 0){
        if(b & 1) res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}
// Greatest common divisor via the Euclidean algorithm, written iteratively.
// gcd(a, 0) == a.
ll gcd(ll a, ll b) {
    while(b != 0){
        const ll rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}
// Least common multiple of a and b.
// Divides by the gcd before multiplying, so no intermediate product overflows
// whenever the result itself fits in ll.
// Returns 0 if either argument is 0: gcd(0, 0) is 0, and dividing by it would be undefined.
ll lcm(ll a, ll b) {
    if(a == 0 || b == 0) return 0;
    return a / gcd(a, b) * b;
}
// Descending-order comparator (e.g. for std::sort): true when a must precede b.
bool cmp(int a, int b){
    return b < a;
}
//


// Problem input: vertex count, operation count, root vertex.
int n, m, r;

// Linked adjacency lists stored in arrays.
// head[u] is the index of the last arc added out of u (-1 when there is none).
// cnt is the next free slot in edge[].
int head[N];
int cnt = 0;
struct node{
    int to;    // arc target vertex
    int nxt;   // next arc out of the same source vertex, or -1
    int c;     // not used in the visible code
};
node edge[N << 1];   // two directed arcs per undirected edge

// Segment-tree node over dfn positions [l, r].
// val holds the range sum; lz holds a pending add not yet pushed to the children.
struct Tree{
    int l;
    int r;
    int val;
    int lz;
};
Tree tree[N * 4];

int val[N];    // input weight of each vertex
int tval[N];   // weights laid out by dfn position: tval[dfn[u]] == val[u]
int son[N];    // heavy child (0 when the vertex is a leaf)
int siz[N];    // subtree size
int dfn[N];    // DFS visiting order, heavy child first
int dep[N];    // depth, root has depth 1
int top[N];    // topmost vertex of the heavy chain containing the vertex
int fa[N];     // parent (the root is its own parent)
int rnk[N];    // inverse of dfn: rnk[dfn[u]] == u
int res = 0;   // unused: every visible `res` is a local that shadows this one
int tot = 0;   // dfn counter

// Insert the undirected edge {u, v} as two directed arcs,
// each pushed onto the front of its source vertex's list.
void add(int u, int v){
    // arc u -> v
    edge[cnt].to = v;
    edge[cnt].nxt = head[u];
    head[u] = cnt;
    ++ cnt;
    // arc v -> u
    edge[cnt].to = u;
    edge[cnt].nxt = head[v];
    head[v] = cnt;
    ++ cnt;
}

// Apply node `index`'s pending lazy add to both of its children, then clear it.
// A child covering len positions gains len * lz in its sum and lz in its own tag.
// The product len * lz can exceed INT_MAX even with lz < mod, so it is computed in
// long long. Both the child's sum and its tag are reduced modulo `mod`, so
// neither grows without bound across repeated pushes.
void pushdown(int index){
    if(tree[index].lz){
        const long long add = tree[index].lz;
        for(int child = index << 1; child <= (index << 1 | 1); ++ child){
            const long long len = tree[child].r - tree[child].l + 1;
            tree[child].val = (int)((tree[child].val + len * add) % mod);
            tree[child].lz = (int)((tree[child].lz + add) % mod);
        }
        tree[index].lz = 0;
    }
}

// Recompute node `index`'s sum from its two children, modulo `mod`.
// Each child sum may be close to mod (an int, up to INT_MAX), so the addition
// is done in long long to avoid signed overflow.
void pushup(int index){
    tree[index].val = (int)((1LL * tree[index << 1].val + tree[index << 1 | 1].val) % mod);
}

void Build(int l, int r, int index){
    tree[index].l = l, tree[index].r = r;
    tree[index].lz = 0;
    if(l == r){
        tree[index].val = tval[l] % mod;
        return;
    }
    int mid = (l + r) >> 1;
    Build(l, mid, index << 1);
    Build(mid + 1, r, index << 1 | 1);
    pushup(index);
}

// Add `val` to every dfn position in [l, r] within the subtree of segment node `index`.
// The increment is normalised into [0, mod) so negative values are also handled.
// A fully covered node absorbs it into its sum (len * add, computed in long long)
// and into its lazy tag; both stay reduced modulo `mod`.
// A partially covered node pushes its tag down, recurses into the overlapping
// children, and recombines.
void updata(int l, int r, int index, int val){
    if(tree[index].l >= l && tree[index].r <= r){
        const long long add = ((long long)val % mod + mod) % mod;
        const long long len = tree[index].r - tree[index].l + 1;
        tree[index].val = (int)((tree[index].val + len * add) % mod);
        tree[index].lz = (int)((tree[index].lz + add) % mod);
        return;
    }
    if(tree[index].lz) pushdown(index);
    int mid = (tree[index].l + tree[index].r) >> 1;
    if(l <= mid) updata(l, r, index << 1, val);
    if(r > mid) updata(l, r, index << 1 | 1, val);
    pushup(index);
}

// Sum of dfn positions [l, r] within segment node `index`, modulo `mod`.
// The two child results can each be close to mod (up to INT_MAX), so they are
// accumulated in long long before reducing.
int query(int l, int r, int index){
    if(l <= tree[index].l && tree[index].r <= r){
        return tree[index].val % mod;
    }
    if(tree[index].lz) pushdown(index);
    int mid = (tree[index].l + tree[index].r) >> 1;
    long long ans = 0;
    if(l <= mid) ans += query(l, r, index << 1);
    if(r > mid) ans += query(l, r, index << 1 | 1);
    return (int)(ans % mod);
}
// --------------------------------

// Sum of vertex weights on the tree path between x and y, modulo `mod`.
// While x and y lie on different heavy chains, take the endpoint whose chain top
// is deeper: its segment [dfn[top[x]], dfn[x]] is one contiguous dfn range.
// Query that range, then jump to the parent of the chain top.
// The running total is reduced after every range, in long long, because summing
// many values each below mod would overflow an int.
int qRange(int x, int y){
    long long res = 0;
    while(top[x] != top[y]){
        if(dep[top[x]] < dep[top[y]]) std::swap(x, y);
        res = (res + query(dfn[top[x]], dfn[x], 1)) % mod;
        x = fa[top[x]];
    }
    // Same chain now: the shallower endpoint has the smaller dfn.
    if(dep[x] > dep[y]) std::swap(x, y);
    res = (res + query(dfn[x], dfn[y], 1)) % mod;
    return (int)res;
}

// Add c, after reducing it modulo `mod`, to every vertex on the tree path between x and y.
// This climbs chains the same way a path query does: each heavy-chain segment
// [dfn[top[x]], dfn[x]] is a contiguous dfn range, updated with one range update.
void updRange(int x, int y, int c){
    c %= mod;
    for(; top[x] != top[y]; x = fa[top[x]]){
        if(dep[top[x]] < dep[top[y]]) std::swap(x, y);
        updata(dfn[top[x]], dfn[x], 1, c);
    }
    // Same chain now: order the endpoints so x is the shallower one.
    if(dep[y] < dep[x]) std::swap(x, y);
    updata(dfn[x], dfn[y], 1, c);
}


// Sum of vertex weights in the subtree rooted at x.
// dfs2 numbers a subtree in one DFS visit, so it occupies the contiguous
// dfn range [dfn[x], dfn[x] + siz[x] - 1].
int qSon(int x){
    const int first = dfn[x];
    const int last = first + siz[x] - 1;
    return query(first, last, 1);
}


void updSon(int x, int val){ //以x为根的子树内所有节点值 + z
    updata(dfn[x], dfn[x] + siz[x] - 1, 1, val);
}


// First DFS of the decomposition. For every vertex u below parent `pre`, it
// records fa[u], dep[u] (one more than the parent's) and siz[u].
// It also picks the heavy child son[u]: the first child with the largest subtree.
// Called as dfs1(root, root), which gives the root depth 1 and makes it its own parent.
// Recursion depth equals the height of the tree.
void dfs1(int u, int pre){
    fa[u] = pre;
    dep[u] = dep[pre] + 1;
    siz[u] = 1;
    int heaviest = -1;
    for(int e = head[u]; ~e; e = edge[e].nxt){
        const int v = edge[e].to;
        if(v != pre){
            dfs1(v, u);
            siz[u] += siz[v];
            if(heaviest < siz[v]){
                heaviest = siz[v];
                son[u] = v;
            }
        }
    }
}

// Second DFS: assigns dfn order and chain tops.
// topu is the topmost vertex of the heavy chain that u belongs to.
// The heavy child is visited first, so each heavy chain gets consecutive dfn numbers.
// Every light child then starts a new chain with itself as the top.
// Also fills rnk (inverse of dfn) and tval (weights in dfn order).
void dfs2(int u, int topu){
    top[u] = topu;
    dfn[u] = ++ tot;
    rnk[tot] = u;
    tval[tot] = val[u];
    if(son[u] == 0) return;   // leaf
    dfs2(son[u], topu);
    for(int e = head[u]; ~e; e = edge[e].nxt){
        const int v = edge[e].to;
        if(v != son[u] && v != fa[u]) dfs2(v, v);
    }
}

// Reads the tree, builds the decomposition and segment tree, then handles m operations:
//   1 x y z : add z to every vertex on the path x..y
//   2 x y   : print the sum of weights on the path x..y
//   3 x z   : add z to every vertex in the subtree of x
//   4 x     : print the sum of weights in the subtree of x
// Any operation code other than 1, 2 or 3 is treated as operation 4.
int main()
{
    scanf("%d%d%d%d", &n, &m, &r, &mod);
    cnt = 0;
    head[0] = -1;
    for(int v = 1; v <= n; ++ v){
        scanf("%d", &val[v]);
        head[v] = -1;   // start with an empty adjacency list
    }
    for(int e = 1; e < n; ++ e){
        int a, b;
        scanf("%d%d", &a, &b);
        add(a, b);
    }
    // Decompose the tree rooted at r, then build the segment tree over dfn order.
    dfs1(r, r);
    dfs2(r, r);
    Build(1, n, 1);
    while(m --){
        int k, x, y, z;
        scanf("%d", &k);
        switch(k){
        case 1:
            scanf("%d%d%d", &x, &y, &z);
            updRange(x, y, z);
            break;
        case 2:
            scanf("%d%d", &x, &y);
            printf("%d\n", qRange(x, y));
            break;
        case 3:
            scanf("%d%d", &x, &y);
            updSon(x, y);
            break;
        default:
            scanf("%d", &x);
            printf("%d\n", qSon(x));
            break;
        }
    }
    return 0;
}

// posted @ 2020-08-15 21:32 by A_sc (阅读 views: 119, 评论 comments: 0)