数据结构--树链剖分

树链剖分

轻重链剖分

利用 dfs 序（同一条重链上的点编号连续），不停地把链顶较深的那个端点跳到其链顶的父节点，直到两个端点位于同一条重链上。

#include <bits/stdc++.h>
using namespace std;
// Contest-style shorthands; mem and pii are not used in the code shown here.
#define mem(a,b) memset(a,b,sizeof(a))
#define pii pair<int,int>
// From here on every `int` token is rewritten to `long long` by the preprocessor,
// which is why the entry point has to be declared as `signed main()`.
#define int long long
// Not referenced in the code shown here.
const int inf = 0x3f3f3f3f;
// Capacity of every per-vertex array; vertices are numbered 1..n, so n must stay below maxn.
const int maxn = 201110;
// n: vertex count, m: number of operations, r: root vertex, M: modulus applied to every sum.
int n,m,r,M;

// Vertex weights, indexed by vertex id 1..n (filled in main).
int a[maxn];

// Forward-star adjacency list: head[u] is u's first edge, Next[e] the edge after e,
// to[e] the endpoint of e. Index 0 terminates a list, so edges are numbered from 2.
// Every undirected tree edge is stored twice (u->v and v->u), so 2*(n-1) slots starting
// at index 2 are needed; sizing these by maxn alone would overflow once n > maxn/2,
// even though all per-vertex arrays accept n up to maxn-1.
int head[maxn],to[maxn*2],Next[maxn*2],cnt = 2;

// Push the directed edge u -> v onto the front of u's adjacency list.
void add(int u,int v)
{
    const int e = cnt++;   // index of the new edge slot
    to[e] = v;
    Next[e] = head[u];
    head[u] = e;
}

// Results of the first DFS, indexed by vertex:
//   fa[u]  parent of u (the root is passed as its own parent),
//   sz[u]  number of vertices in u's subtree,
//   son[u] heavy child of u (stays 0 when u has no children),
//   d[u]   depth of u (the root gets depth 1).
int fa[maxn];
int sz[maxn];
int son[maxn];
int d[maxn];

// First pass of heavy-light decomposition: records parent, depth and subtree size
// for every vertex below u, and picks as heavy child the first child whose subtree
// is strictly the largest seen so far. `pre` is u's parent.
void dfs1(int u,int pre)
{
    fa[u] = pre;
    d[u] = d[pre] + 1;   // for the root, d[pre] is d[u] itself, still 0, so depth 1
    sz[u] = 1;
    int heaviest = -1;   // subtree size of the current heavy child
    for(int e = head[u]; e != 0; e = Next[e])
    {
        const int v = to[e];
        if(v == pre) continue;            // do not walk back up to the parent
        dfs1(v,u);
        sz[u] += sz[v];
        if(sz[v] <= heaviest) continue;   // ties keep the earlier child
        heaviest = sz[v];
        son[u] = v;
    }
}

// Results of the second DFS:
//   dfn[u] 1-based visiting position of u; a heavy child gets dfn[parent]+1, so every
//          heavy chain occupies consecutive positions,
//   top[u] topmost vertex of the heavy chain that contains u,
//   w[i]   weight of the vertex whose dfn is i (the array the segment tree is built on),
//   idx    last position handed out.
int dfn[maxn];
int top[maxn];
int w[maxn];
int idx;

// Second pass: numbers vertices so the heavy child is visited right after its parent,
// and records each vertex's chain top. `pre` is the top of the chain u belongs to;
// every light child starts a new chain headed by itself.
void dfs2(int u,int pre)
{
    top[u] = pre;
    dfn[u] = ++idx;
    w[dfn[u]] = a[u];
    if(son[u] == 0) return;   // no children below u
    dfs2(son[u], pre);        // heavy child continues the current chain
    for(int e = head[u]; e != 0; e = Next[e])
    {
        const int v = to[e];
        if(v != son[u] && v != fa[u])
            dfs2(v, v);       // light child heads its own chain
    }
}

// One segment-tree node: covers positions [l, r], stores the sum over that range,
// and a pending addition `lazy` that has not yet been pushed to its children.
struct tree_node
{
    int l;      // left end of the covered range
    int r;      // right end of the covered range
    int sum;    // sum over the covered range
    int lazy;   // value still to be added to both children
};

// Segment tree over positions 1..n with range add and range sum, all modulo M.
// Invariant kept by every member: each t[k].sum and t[k].lazy lies in [0, M),
// so a fully covered node's stored sum can be returned directly by query().
struct segtree
{
    tree_node t[maxn*4];

    // Reduce v into [0, M), also when v is negative.
    int modulo(int v)
    {
        v %= M;
        return v < 0 ? v + M : v;
    }
    // Add v (already in [0, M)) to every position covered by node k and
    // remember it as a pending tag for k's children.
    void apply_tag(int k,int v)
    {
        t[k].sum = (t[k].sum + v*(t[k].r - t[k].l + 1))%M;
        t[k].lazy = (t[k].lazy + v)%M;
    }
    // Recompute node k's sum from its two children.
    void pushup(int k)
    {
        t[k].sum = (t[k*2].sum+t[k*2+1].sum)%M;
    }
    // Hand node k's pending tag down to its children. The tag must be applied
    // modulo M here as well: otherwise children accumulate unreduced sums and
    // tags, a query hitting a single leaf returns an unreduced value, and
    // repeated pushes can overflow lazy*length.
    void pushdown(int k)
    {
        if(t[k].lazy)
        {
            apply_tag(k*2, t[k].lazy);
            apply_tag(k*2+1, t[k].lazy);
            t[k].lazy = 0;
        }
    }
    // Build node k over positions [l, r] from w[].
    void build(int k,int l,int r)
    {
        t[k].l = l,t[k].r = r;
        t[k].lazy = 0;
        if(l == r)
        {
            t[k].sum = modulo(w[l]);   // input weights are not guaranteed to be below M
            return;
        }
        int mid = (l+r)/2;
        build(k*2,l,mid);
        build(k*2+1,mid+1,r);
        pushup(k);
    }
    // Add x to every position in [l, r]; does nothing when l > r.
    void update(int k,int l,int r,int x)
    {
        if(l > r) return;
        x = modulo(x);   // keep the tag, and x*length, bounded and non-negative
        if(l <= t[k].l && t[k].r <= r)
        {
            apply_tag(k, x);
            return;
        }
        pushdown(k);
        if(t[k*2].r >= l) update(k*2,l,r,x);
        if(t[k*2+1].l <= r) update(k*2+1,l,r,x);
        pushup(k);
    }
    // Sum of positions [l, r] modulo M; 0 when l > r.
    int query(int k,int l,int r)
    {
        if(l > r) return 0;
        if(l <= t[k].l && t[k].r <= r)
        {
            return t[k].sum;
        }
        int res = 0;
        pushdown(k);
        if(t[k*2].r >= l) res = (res + query(k*2,l,r))%M;
        if(t[k*2+1].l <= r) res = (res + query(k*2+1,l,r))%M;
        return res;
    }
}st;


// Add z to every vertex on the tree path between x and y.
// Repeatedly lifts the endpoint whose chain top is deeper: the stretch of its chain
// from the top down to it is one contiguous dfn range, updated before jumping to
// the parent of that chain top.
void modify(int x,int y,int z)
{
    for(; top[x] != top[y]; x = fa[top[x]])
    {
        if(d[top[x]] < d[top[y]]) swap(x,y);
        st.update(1,dfn[top[x]],dfn[x],z);
    }
    // Both endpoints now lie on one chain; cover it from the shallower to the deeper.
    const int upper = d[x] > d[y] ? y : x;
    const int lower = d[x] > d[y] ? x : y;
    st.update(1,dfn[upper],dfn[lower],z);
}

// Sum (mod M) of the vertex values on the tree path between x and y,
// walking the same chain segments as modify().
int search(int x,int y)
{
    int total = 0;
    while(top[x] != top[y])
    {
        if(d[top[y]] > d[top[x]]) swap(x,y);   // lift the endpoint with the deeper chain top
        const int part = st.query(1,dfn[top[x]],dfn[x]);
        total = (total + part)%M;
        x = fa[top[x]];
    }
    // Same chain: query from the shallower endpoint's dfn to the deeper one's.
    const bool xDeeper = d[x] > d[y];
    const int lo = xDeeper ? dfn[y] : dfn[x];
    const int hi = xDeeper ? dfn[x] : dfn[y];
    return (total + st.query(1,lo,hi))%M;
}

signed main()
{
#ifdef ONLINE_JUDGE
#else
    freopen("data.in", "r", stdin);
#endif
    cin>>n>>m>>r>>M;
    for(int i = 1; i <= n; i++) 
    {
        cin>>a[i];
    }
    for(int i = 1,x,y; i < n; i++) 
    {
        cin>>x>>y;
        add(x,y);add(y,x);
    }
    dfs1(r,r);dfs2(r,r);
    st.build(1,1,n);
    int opt,x,y,z;
    while(m--)
    {
        cin>>opt;
        if(opt == 1)
        {
            cin>>x>>y>>z;
            modify(x,y,z);
        }
        else if(opt == 2)
        {
            cin>>x>>y;
            cout<<search(x,y)<<endl;
        }
        else if(opt == 3)
        {
            cin>>x>>z;
            st.update(1,dfn[x],dfn[x]+sz[x]-1,z);
        }
        else if(opt == 4)
        {
            cin>>x;
            cout<<st.query(1,dfn[x],dfn[x]+sz[x]-1)<<endl;
        }
    }
    return 0;
}
posted @ 2020-01-22 11:45  hezongdnf  阅读(131)  评论(0)  编辑  收藏  举报