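/*
 * 洛谷 P3384 【模板】树链剖分
 * (Luogu P3384 — [Template] Heavy-Light Decomposition)
 *
 * 题目传送门 (problem link): https://www.luogu.com.cn/problem/P3384
 *
 * 代码如下 (the code follows):
 */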

#include <iostream>
#include <cstdio>
#include <vector>
// Capacity of the per-vertex arrays; sized for up to ~1e5 vertices
// (assumed problem limit — confirm against the statement).
// A typed constant instead of a macro: same value in every array bound,
// but it is scoped, has a type, and shows up in the debugger.
const int maxn = 100005;
// NOTE(review): under C++17, `using namespace std` also makes std::size and
// std::data visible, so global arrays named `size`/`data` become ambiguous
// when used unqualified; write them as ::size / ::data (or rename them).
using namespace std;
typedef long long ll;
// One directed adjacency-list edge. `data` is the target vertex; `next` is the
// index of the previously added edge leaving the same source (0 ends the list).
// Twice maxn entries, because every undirected edge is stored in both directions.
struct T {
    int data;
    int next;
} e[maxn << 1];
// Heavy-light decomposition state, indexed by vertex id (vertices are 1..n).
int top[maxn];    // head vertex of the heavy chain containing the vertex
int son[maxn];    // heavy child: child with the largest subtree (0 if none)
int size[maxn];   // number of vertices in the subtree
int depth[maxn];  // distance from the root
int data[maxn];   // initial vertex weight read from the input
int fa[maxn];     // parent vertex (0 for the root)
// Adjacency-list heads and the running edge counter (edge ids start at 1; 0 = none).
int head[maxn];
int cnt;
// Vertices in the order dfs2 visits them; main pushes a placeholder at index 0
// so real positions start at 1.
vector<int> vec;
// Modulus applied to sums.
int p;
// Segment-tree node covering DFS positions [l, r]. `sum` holds the total of
// the covered positions; `lazy` is a pending per-position addition that has
// been applied to `sum` but not yet pushed to the two children.
struct node {
    int l;
    int r;
    ll sum;
    ll lazy;
} tree[maxn << 2];
void add(int x, int y)
{
    ++ cnt;
    e[cnt].data = y;
    e[cnt].next = head[x];
    head[x] = cnt;
}
void dfs1(int x)
{
    size[x] = 1;
    for(int i = head[x]; i != 0; i = e[i].next){
        int r = e[i].data;
        if(r != fa[x]){
            depth[r] = depth[x] + 1;
            fa[r] = x;
            dfs1(r);
            size[x] += size[r];
            if(!son[x] || size[r] > size[son[x]])
                son[x] = r;
        }
    }
}
// mp[v] = index of vertex v in `vec`, i.e. its DFS position and the position
// used for it in the segment tree (assigned by dfs2).
int mp[maxn] = {0};
// Second pass of heavy-light decomposition: x belongs to the chain headed by k.
// It records x's position in `vec` (mp[x]) and its chain head (top[x]).
// The heavy child is visited first, so a heavy chain occupies consecutive
// positions. Each light child then starts a new chain headed by itself.
// A call with x == 0 (no heavy child) does nothing.
void dfs2(int x, int k)
{
    if (!x)
        return;
    top[x] = k;
    mp[x] = vec.size();  // the index this push_back will occupy
    vec.push_back(x);
    dfs2(son[x], k);
    for (int i = head[x]; i; i = e[i].next) {
        int child = e[i].data;
        if (child == fa[x] || child == son[x])
            continue;
        dfs2(child, child);
    }
}
void build(int l, int r, int k)
{
    tree[k].l = l;
    tree[k].r = r;
    if(l == r){
        tree[k].sum = data[vec[l]];
        return;
    }
    int mid = (l + r) / 2;
    build(l, mid, 2*k);
    build(mid + 1, r, 2*k+1);
    tree[k].sum = tree[2*k].sum + tree[2*k+1].sum;
    tree[k].sum %= p;
}
void down(int k)
{
    if(tree[k].lazy == 0)
        return;
    tree[2*k].sum += (tree[2*k].r - tree[2*k].l + 1) * tree[k].lazy;
    tree[2*k+1].sum += (tree[2*k+1].r - tree[2*k+1].l + 1) * tree[k].lazy;
    tree[2*k].lazy +=  tree[k].lazy;
    tree[2*k + 1].lazy += tree[k].lazy;
    tree[k].lazy = 0;
}
void add(int l, int r, int z, int k)
{
    if(tree[k].l >= l && tree[k].r <= r){
        tree[k].sum += ((tree[k].r - tree[k].l + 1) * z) % p;
        tree[k].sum %= p;
        tree[k].lazy += z;
        return;
    }
    down(k);
    int mid = (tree[k].l + tree[k].r) / 2;
    if(l <= mid)
        add(l, r, z, 2*k);
    if(r > mid)
        add(l, r, z, 2*k+1);
    tree[k].sum = tree[2*k].sum + tree[2*k + 1].sum;
    tree[k].sum %= p;
}
// Adds z to every vertex on the tree path between x and y.
// While the endpoints are on different heavy chains, the endpoint whose chain
// head is deeper (y on ties) has its chain segment updated. That endpoint then
// jumps to the parent of its chain head. Once both are on one chain, the
// segment between them is updated.
void add1(int x, int y, int z)
{
    while (top[x] != top[y]) {
        int &u = depth[top[x]] > depth[top[y]] ? x : y;
        add(mp[top[u]], mp[u], z, 1);
        u = fa[top[u]];
    }
    int shallow = depth[x] > depth[y] ? y : x;
    int deep = depth[x] > depth[y] ? x : y;
    add(mp[shallow], mp[deep], z, 1);
}
// Returns the sum over DFS positions [l, r] inside node k's range (call with k = 1).
// A fully covered node's stored sum is returned as is. Otherwise pending tags
// are pushed down, and the child results are combined modulo p.
ll query(int l, int r, int k)
{
    bool covered = l <= tree[k].l && tree[k].r <= r;
    if (covered)
        return tree[k].sum;
    down(k);
    int mid = (tree[k].l + tree[k].r) / 2;
    ll total = 0;
    if (l <= mid)
        total = query(l, r, 2 * k) % p;
    if (mid < r)
        total = (total + query(l, r, 2 * k + 1)) % p;
    return total;
}
// Returns the sum of vertex values on the tree path between x and y, modulo p.
// It walks chains exactly as add1 does: the endpoint with the deeper chain head
// (y on ties) contributes its chain segment and jumps up. Then the final
// segment on the shared chain is added.
ll get1(int x, int y)
{
    ll ans = 0;
    while (top[x] != top[y]) {
        int &u = depth[top[x]] > depth[top[y]] ? x : y;
        ans = (ans + query(mp[top[u]], mp[u], 1)) % p;
        u = fa[top[u]];
    }
    int shallow = depth[x] > depth[y] ? y : x;
    int deep = depth[x] > depth[y] ? x : y;
    return (ans + query(mp[shallow], mp[deep], 1)) % p;
}
inline void add2(int x, int y)
{
    add(mp[x], mp[x] + size[x] - 1, y, 1);
}
// Returns the sum of vertex values in the subtree of x, modulo p. The subtree
// is the contiguous DFS range [mp[x], mp[x] + size[x] - 1].
inline ll get2(int x)
{
    // ::size, not size: avoids ambiguity with C++17's std::size under
    // `using namespace std`.
    return query(mp[x], mp[x] + ::size[x] - 1, 1) % p;
}
int main()
{
    int n, m, r;
    scanf("%d%d%d%d", &n, &m, &r, &p);
    for(int i = 1; i <= n; i ++)
        scanf("%d", &data[i]);
    for(int i = 1; i < n; i ++){
        int x, y;
        scanf("%d%d", &x, &y);
        add(x, y);
        add(y, x);
    }
    vec.push_back(0);
    dfs1(r);
    dfs2(r, r);
    build(1, n, 1);
    for(int i = 1; i <= m; i ++){
        int opt;
        scanf("%d", &opt);
        if(opt == 1){
            int x, y, z;
            scanf("%d%d%d", &x, &y, &z);
            add1(x, y, z);
        }
        else if(opt == 2){
            int x, y;
            scanf("%d%d", &x, &y);
            printf("%lld\n", get1(x, y) % p);
        }
        else if(opt == 3){
            int x, y;
            scanf("%d%d", &x, &y);
            add2(x, y);
        }
        else {
            int x;
            scanf("%d", &x);
            printf("%lld\n", get2(x) % p);
        }
    }
    return 0;
}
// Originally posted 2019-08-27 20:29 by whisperlzw (blog stats at capture time: 162 views, 0 comments).