【模板】树链剖分

题意简述

如题，已知一棵包含 N 个结点的树（连通且无环），以结点 R 为根（代码中的 r），每个结点上包含一个数值。需要支持以下操作，所有数值运算及输出结果均对 P 取模（代码中的 mod）：
操作1:1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z
操作2:2 x y 表示求树从x到y结点最短路径上所有节点的值之和
操作3:3 x z 表示将以x为根节点的子树内所有节点值都加上z
操作4:4 x 表示求以x为根节点的子树内所有节点值之和

代码

#include <cstdio>
#include <algorithm>
// Shorthand for read-only int parameters; always used as `ci&` (const int&).
typedef const int ci;
// Maximum number of tree nodes (with a little slack).
const int Maxn = 100010;
// 64-bit integer used for intermediate modular arithmetic.
typedef long long ll;

// Input sizes: node count n, operation count m, root r, modulus mod.
int n, m, r, mod;
// Counters: cnt1 = stored edges, cnt2 = DFS ids handed out, cnt3 = leaves built.
int cnt1, cnt2, cnt3;
// Scratch variables used while reading edges and operations.
int u, v, opt, x, y, z;

// Adjacency list as linked edge records: h[node] is the first edge index,
// to[e] the edge's target, nxt[e] the next edge of the same node (0 ends the list).
int h[Maxn];
int to[Maxn << 1];
int nxt[Maxn << 1];

// va: values laid out in DFS-id order; w: value per node;
// fa: parent; hvs: heavy child; dep: depth; sz: subtree size;
// id: DFS id of a node; top: head of the heavy chain containing the node.
int va[Maxn], w[Maxn], fa[Maxn], hvs[Maxn], dep[Maxn], sz[Maxn], id[Maxn], top[Maxn];
char ch;
void _add(int& x, const int k) {x = ((ll)x + k) % mod;}
// Append the directed edge u -> v to u's adjacency list.
// Edge indices start at 1, so index 0 can mark the end of a list.
inline void add_edge(ci& u, ci& v)
{
    ++cnt1;
    const int e = cnt1;
    to[e] = v;        // target of the new edge
    nxt[e] = h[u];    // link in front of u's existing edges
    h[u] = e;         // new edge becomes u's list head
}
// First pass of heavy-light decomposition over the subtree of u.
// For every node below u it sets fa (parent) and dep (depth). For u and every
// node below it, it sets sz (subtree size) and hvs (heavy child: the child with
// the largest subtree; stays 0 for a leaf).
// Precondition: fa[u] and dep[u] are already set for the starting node.
// NOTE(review): recursive; a path-shaped tree recurses about n levels deep,
// which may exceed small default stack sizes -- confirm for the target platform.
void dfs1(ci& u)
{
    sz[u] = 1;
    // `register` was dropped: that storage class is ill-formed since C++17.
    for (int i = h[u]; i; i = nxt[i])
    {
        const int child = to[i];
        if (child == fa[u]) continue;  // do not walk back up to the parent
        fa[child] = u;
        dep[child] = dep[u] + 1;
        dfs1(child);
        sz[u] += sz[child];
        // sz[0] is 0, so the first child always replaces an unset hvs[u].
        if (sz[hvs[u]] < sz[child]) hvs[u] = child;
    }
}
// Second pass of heavy-light decomposition: assigns preorder ids.
// Visiting the heavy child first gives every heavy chain consecutive ids, and
// preorder numbering makes each subtree a contiguous id range.
//   u  - current node
//   tp - head of the heavy chain that u belongs to
// Writes id[u], top[u], and copies w[u] into va at position id[u].
// Requires dfs1 to have filled fa and hvs.
void dfs2(ci& u, ci& tp)
{
    id[u] = ++cnt2;
    va[cnt2] = w[u];
    top[u] = tp;
    // Heavy child continues the current chain.
    if (hvs[u]) dfs2(hvs[u], tp);
    // `register` was dropped: that storage class is ill-formed since C++17.
    for (int i = h[u]; i; i = nxt[i])
    {
        const int child = to[i];
        // Every light child starts a new chain headed by itself.
        if (child != fa[u] && child != hvs[u]) dfs2(child, child);
    }
}
struct Segment_Tree
{
    int a[Maxn << 2], la[Maxn << 2];
    void push_up(ci& x) {a[x] = ((ll)a[x << 1] + a[x << 1 | 1]) % mod; }
    void push_down(ci& x, ci& len)
    {
        _add(a[x << 1], la[x] * (len - (len >> 1)) % mod);
        _add(a[x << 1 | 1], la[x] * (len >> 1) % mod);
        _add(la[x << 1], la[x]);
        _add(la[x << 1 | 1], la[x]);
        la[x] = 0;
    }
    void build(ci& x, ci& l, ci& r)
    {
        if (l == r)	{a[x] = va[++cnt3]; return; }
        int mid = (l + r) >> 1;
        build(x << 1, l, mid);
        build(x << 1 | 1, mid + 1, r);
        push_up(x);
    }
    void add(ci& x, ci& l, ci& r, ci& l1, ci& r1, ci& k)
    {
        if (l1 <= l && r <= r1)
        {
            _add(a[x], (r - l + 1) * k % mod);
            _add(la[x], k);
            return;
        }
        if (la[x]) push_down(x, r - l + 1);
        int mid = (l + r) >> 1;
        if (l1 <= mid) add(x << 1, l, mid, l1, r1, k);
        if (r1 >  mid) add(x << 1 | 1, mid + 1, r, l1, r1, k);
        push_up(x);
    }
    int query(ci& x, ci& l, ci& r, ci& l1, ci& r1, int ans = 0)
    {
        if (l1 <= l && r <= r1)	return a[x];
        if (la[x]) push_down(x, r - l + 1);
        int mid = (l + r) >> 1;
        if (l1 <= mid) _add(ans, query(x << 1, l, mid, l1, r1));
        if (r1 >  mid) _add(ans, query(x << 1 | 1, mid + 1, r, l1, r1));
        return ans;
    }
}seg;
// Path update: add z to every node on the tree path between x and y.
// While x and y lie on different heavy chains, take the endpoint whose chain
// head is deeper, update the chain segment from its head down to it, and jump
// above that head. Once both are on one chain, update the segment between them.
inline void add1(int x, int y, ci& z)
{
    while (top[x] != top[y])
    {
        if (dep[top[y]] > dep[top[x]]) std::swap(x, y);
        const int head = top[x];
        seg.add(1, 1, n, id[head], id[x], z);
        x = fa[head];
    }
    // Same chain: the shallower endpoint has the smaller DFS id.
    const bool x_deeper = dep[x] > dep[y];
    const int lo = x_deeper ? id[y] : id[x];
    const int hi = x_deeper ? id[x] : id[y];
    seg.add(1, 1, n, lo, hi, z);
}
// Path query: sum of node values on the tree path between x and y, modulo `mod`.
// s is the starting value of the accumulator (0 by default).
// This uses the same chain-jumping walk as the path update.
inline int query1(int x, int y, int s = 0)
{
    while (top[x] != top[y])
    {
        if (dep[top[y]] > dep[top[x]]) std::swap(x, y);
        const int head = top[x];
        _add(s, seg.query(1, 1, n, id[head], id[x]));
        x = fa[head];
    }
    if (dep[y] < dep[x]) std::swap(x, y);
    // Now x is the shallower endpoint on the shared chain.
    _add(s, seg.query(1, 1, n, id[x], id[y]));
    return s;
}
// Subtree update: add z to every node in the subtree rooted at x.
// Preorder ids make the subtree the contiguous range [id[x], id[x] + sz[x] - 1].
inline void add2(ci& x, ci& z)
{
    const int first = id[x];
    const int last = first + sz[x] - 1;
    seg.add(1, 1, n, first, last, z);
}
// Subtree query: sum of values in the subtree rooted at x, read from the
// contiguous DFS-id range [id[x], id[x] + sz[x] - 1].
inline int query2(ci& x)
{
    const int first = id[x];
    const int last = first + sz[x] - 1;
    return seg.query(1, 1, n, first, last);
}
int main()
{
    scanf("%d%d%d%d", &n, &m, &r, &mod);
    for (register int i = 1; i <= n; ++i) scanf("%d", &w[i]);
    for (register int i = 1; i < n; ++i)
    {
        scanf("%d%d", &u, &v);
        add_edge(u, v); add_edge(v, u); 
    }
    fa[r] = r; dfs1(r); dfs2(r, r); seg.build(1, 1, n);
    for (register int i = 1; i <= m; ++i)
    {
        scanf("%d", &opt);
        if (opt == 1) {scanf("%d%d%d", &x, &y, &z); add1(x, y, z % mod); }
        else if (opt == 2) {scanf("%d%d", &x, &y); printf("%d\n", query1(x, y)); }
        else if (opt == 3) {scanf("%d%d", &x, &z); add2(x, z % mod); }
        else {scanf("%d", &x); printf("%d\n", query2(x)); }
    }
}
posted @ 2018-11-08 16:09  xuyixuan  阅读(126)  评论(0编辑  收藏  举报