树链剖分

第一次码模板就出了一个小 bug，改了就过了。

传送门

#include <bits/stdc++.h>
using namespace std;
// Capacity of every per-vertex array in this file (200000 plus a small margin).
const int N = 200000 + 10;
// Number of directed edges stored so far; also the index of the newest edge.
int cnt;
// head[u] is the index of the most recently added edge leaving u (0 = none).
int head[N];
// One directed edge of the adjacency list; edges leaving the same vertex are
// chained through `nex`, and index 0 acts as the end-of-list marker.
struct edge
{
    int to;  // destination vertex of this edge
    int nex; // index of the next edge leaving the same source vertex (0 = end)
};
// Edge pool, indexed from 1. Sized 2 * N because add_edge is called twice per
// undirected edge (once in each direction).
edge e[N << 1];
// Appends the directed edge u -> v to the adjacency list by pushing it onto
// the front of u's edge chain.
inline void add_edge(int u, int v)
{
    const int id = ++cnt; // edge indices start at 1; 0 means "no edge"
    e[id].to = v;
    e[id].nex = head[u];
    head[u] = id;
}
// Values read from the first input line: vertex count, number of operations,
// root vertex, and the modulus applied to every reported sum.
int n;
int m;
int root;
int mod;
int a[N];   // initial weight of each vertex, indexed by vertex id
int w[N];   // vertex weights re-indexed by DFS timestamp (w[dfn[u]] = a[u])
int dfn[N]; // DFS timestamp assigned to each vertex by dfs2
int siz[N]; // size of the subtree rooted at each vertex
int son[N]; // heavy child of each vertex (0 when the vertex has no child)
int f[N];   // parent of each vertex (the root is recorded as its own parent)
int top[N]; // topmost vertex of the chain that contains each vertex
int dep[N]; // depth of each vertex
int tim;    // running timestamp counter used by dfs2
// Segment tree with lazy propagation over the DFS-ordered weights w[1..n].
// Supports "add k to every position in [l, r]" and "sum of [l, r]", with all
// results reduced modulo the global `mod`.
//
// Convention: a node's `lazy` is an increment that has NOT yet been applied to
// that node's own `sum`. push_down() applies it to the node and forwards it to
// the children; push_up() therefore flushes both children before adding them.
//
// Every stored `sum` and `lazy` is kept in [0, mod) and all intermediate
// arithmetic is done in 64 bits, so neither `lazy * length` nor the addition
// of two residues can overflow `int`.
struct segment_tree
{
    struct tnode
    {
        int l, r; // closed interval of positions covered by this node
        int sum;  // interval sum modulo `mod`, excluding this node's own `lazy`
        int lazy; // pending per-position increment, kept in [0, mod)
    };
    tnode t[N << 2];

    // Reduces v into the canonical range [0, mod).
    static inline int reduce(long long v)
    {
        v %= mod;
        if(v < 0) v += mod;
        return static_cast<int>(v);
    }

    // Applies the pending increment of `root` to its own sum and hands it to
    // both children (if any).
    inline void push_down(int root)
    {
        if(t[root].lazy == 0) return;
        const long long len = t[root].r - t[root].l + 1;
        // 64-bit product: lazy (< mod) times len (<= N) does not fit in int.
        t[root].sum = reduce(t[root].sum + t[root].lazy * len);
        if(t[root].l != t[root].r)
        {
            const int ch = root << 1;
            t[ch].lazy = reduce(static_cast<long long>(t[ch].lazy) + t[root].lazy);
            t[ch + 1].lazy = reduce(static_cast<long long>(t[ch + 1].lazy) + t[root].lazy);
        }
        t[root].lazy = 0;
    }

    // Recomputes the sum of `root` from its children. The children are flushed
    // first because their `sum` does not include their own pending `lazy`.
    inline void push_up(int root)
    {
        const int ch = root << 1;
        push_down(ch);
        push_down(ch + 1);
        t[root].sum = reduce(static_cast<long long>(t[ch].sum) + t[ch + 1].sum);
    }

    // Builds the subtree `root` covering positions [l, r] from w[l..r].
    inline void build(int root, int l, int r)
    {
        t[root].l = l, t[root].r = r;
        t[root].lazy = 0; // cleared on every node so the tree can be rebuilt safely
        if(l != r)
        {
            const int ch = root << 1;
            const int mid = (l + r) >> 1;
            build(ch, l, mid);
            build(ch + 1, mid + 1, r);
            push_up(root);
        }
        else
        {
            t[root].sum = reduce(w[l]);
        }
    }

    // Adds k to every position in [l, r]. The caller must pass a range that
    // lies inside the interval covered by `root`.
    void change(int root, int l, int r, int k)
    {
        push_down(root);
        if(t[root].l >= l && t[root].r <= r)
        {
            // Left pending on this node; it is applied by the next push_down.
            t[root].lazy = reduce(static_cast<long long>(t[root].lazy) + k);
            return;
        }
        const int ch = root << 1;
        const int mid = (t[root].l + t[root].r) >> 1;
        if(r <= mid) change(ch, l, r, k);
        else if(l > mid) change(ch + 1, l, r, k);
        else change(ch, l, mid, k), change(ch + 1, mid + 1, r, k);
        push_up(root);
    }

    // Returns the sum of positions [l, r] modulo `mod`, in [0, mod).
    int query(int root, int l, int r)
    {
        push_down(root);
        if(t[root].l >= l && t[root].r <= r)
        {
            return t[root].sum;
        }
        const int ch = root << 1;
        const int mid = (t[root].l + t[root].r) >> 1;
        if(r <= mid) return query(ch, l, r);
        if(l > mid) return query(ch + 1, l, r);
        return reduce(static_cast<long long>(query(ch, l, mid)) + query(ch + 1, mid + 1, r));
    }
} st;
// First decomposition pass: starting at u (whose parent is fa), records the
// parent, depth and subtree size of every vertex below u and picks each
// vertex's heavy child, i.e. the child with the largest subtree (the first
// such child in adjacency order wins ties).
void dfs1(int u, int fa)
{
    f[u] = fa;
    siz[u] = 1;
    dep[u] = 1 + dep[fa];
    int heaviest = -1; // largest child subtree size seen so far
    for(int id = head[u]; id != 0; id = e[id].nex)
    {
        const int child = e[id].to;
        if(child != fa)
        {
            dfs1(child, u);
            siz[u] += siz[child];
            if(heaviest < siz[child])
            {
                heaviest = siz[child];
                son[u] = child;
            }
        }
    }
}
// Second decomposition pass: assigns u the next DFS timestamp, records t as
// the top of u's chain and copies u's weight into the timestamp-indexed array.
// The heavy child is visited first and stays on the same chain, so the
// vertices of one chain receive consecutive timestamps; every other child
// starts a new chain with itself as the top.
void dfs2(int u, int t)
{
    top[u] = t;
    dfn[u] = ++tim;
    w[dfn[u]] = a[u];
    const int heavy = son[u];
    if(heavy != 0)
    {
        dfs2(heavy, t);
        for(int id = head[u]; id != 0; id = e[id].nex)
        {
            const int v = e[id].to;
            if(v != f[u] && v != heavy) dfs2(v, v);
        }
    }
}
// Adds k (reduced modulo `mod`) to every vertex on the path between x and y.
// While the endpoints are on different chains, the endpoint whose chain top
// is deeper is lifted: its chain segment is updated and it jumps to the
// parent of that chain's top. Once both are on one chain, the remaining
// segment between them is updated.
inline void op1(int x, int y, int k)
{
    const int delta = k % mod;
    for(; top[x] != top[y]; x = f[top[x]])
    {
        if(dep[top[y]] > dep[top[x]]) swap(x, y);
        st.change(1, dfn[top[x]], dfn[x], delta);
    }
    const bool x_is_deeper = dep[x] > dep[y];
    const int shallow = x_is_deeper ? y : x;
    const int deep = x_is_deeper ? x : y;
    st.change(1, dfn[shallow], dfn[deep], delta);
}
// Returns the sum of the weights on the path between x and y, modulo `mod`.
// The path is walked chain by chain exactly as in op1. The running total is
// held in 64 bits and reduced after every chain, because adding one residue
// per chain into a plain `int` can overflow when `mod` is large.
inline int op2(int x, int y)
{
    long long total = 0;
    while(top[x] != top[y])
    {
        // Always lift the endpoint whose chain top is deeper.
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        total = (total + st.query(1, dfn[top[x]], dfn[x])) % mod;
        x = f[top[x]];
    }
    // Same chain now: the shallower vertex has the smaller timestamp.
    if(dep[x] > dep[y]) swap(x, y);
    total = (total + st.query(1, dfn[x], dfn[y])) % mod;
    return static_cast<int>(total);
}
inline void op3(int x, int k)
{
    st.change(1, dfn[x], dfn[x] + siz[x] - 1, k);
}
// Returns the segment-tree sum over the subtree of x, which occupies the
// contiguous timestamp range starting at dfn[x] and spanning siz[x] positions.
inline int op4(int x)
{
    const int first = dfn[x];
    const int last = first + siz[x] - 1;
    return st.query(1, first, last);
}

// Reads the tree and the operations from stdin, runs both decomposition
// passes, builds the segment tree and answers each operation on stdout.
// Operation codes: 1 x y z -> path add, 2 x y -> path sum,
// 3 x z -> subtree add, anything else (4 x) -> subtree sum.
signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> m >> root >> mod;
    for(int v = 1; v <= n; ++v) cin >> a[v];

    int x, y;
    for(int read = 1; read < n; ++read)
    {
        cin >> x >> y;
        // Store each undirected edge as two directed ones.
        add_edge(x, y);
        add_edge(y, x);
    }

    // The root is passed as its own parent / chain top.
    dfs1(root, root);
    dfs2(root, root);
    st.build(1, 1, n);

    for(; m != 0; --m)
    {
        int op, z;
        cin >> op;
        switch(op)
        {
        case 1:
            cin >> x >> y >> z;
            op1(x, y, z);
            break;
        case 2:
            cin >> x >> y;
            cout << op2(x, y) << "\n";
            break;
        case 3:
            cin >> x >> y;
            op3(x, y);
            break;
        default:
            cin >> x;
            cout << op4(x) << "\n";
            break;
        }
    }
    return 0;
}

posted @ 2022-06-29 19:55  std&ice  阅读(78)  评论(0)    收藏  举报