H [HAOI2015]树上操作 树剖板子

链接:https://ac.nowcoder.com/acm/contest/27836/H
来源:牛客网

题目描述

有一棵点数为 N 的树,以点 1 为根,且树点有边权。
然后有 M 个 操作,分为三种: 
操作 1 :把某个节点 x 的点权增加 a 。
操作 2 :把某个节点 x 为根的子树中所有点的点权都增加 a 。
操作 3 :询问某个节点 x 到根的路径中所有点的点权和。

输入描述:

第一行包含两个整数N, M。表示点数和操作数。
接下来一行N个整数,表示树中节点的初始权值。
接下来N-1行每行三个正整数 fr, to ,表示该树中存在一条边 (fr, to) 。
再接下来M行,每行分别表示一次操作。其中第一个数表示该操作的种类( 1-3 ),之后接这个操作的参数( x 或者 x a ) 。

输出描述:

对于每个询问操作,输出该询问的答案。答案之间用换行隔开。
示例1

输入

复制
5 5
1 2 3 4 5
1 2
1 4
2 3
2 5
3 3
1 2 1
3 5
2 1 2
3 3

输出

复制
6
9
13

备注:

对于 100% 的数据, N,M≤100000N,M\le 100000N,M100000 ,且所有输入数据的绝对值都不会超过 10^6 。



分析

 

板子题,求出根节点到某个节点的权值和,单点修改和区间修改

//-------------------------代码----------------------------

#define int ll
const int N = 1e5+10;
int n,m,w[N];
int fa[N],dep[N],sz[N],hson[N],top[N],dfn[N],rnk[N],last[N],tot;
V<int> edge[N];
inline ll read() { ll s = 0, w = 1; char ch = getchar(); for (; !isdigit(ch); ch = getchar()) if (ch == '-') w = -1; for (; isdigit(ch); ch = getchar())    s = (s << 1) + (s << 3) + (ch ^ 48); return s * w; }
struct Seg_tree {
#define mid (tr[u].l + tr[u].r >> 1)
#define ul (u<<1)
#define ur (u<<1|1)
#define len(u) (tr[u].r-tr[u].l+1)
    struct node {
        int l,r;
        int lazy;
        int sum;
    } tr[N<<2];
    
    void push_up(int u) {
        tr[u].sum = tr[ul].sum + tr[ur].sum;
    }
    
    void push_down(int u) {
        if(tr[u].lazy == 0) return;
        tr[ul].sum += tr[u].lazy *  len(ul);
        tr[ur].sum += tr[u].lazy *  len(ur);
        tr[ul].lazy += tr[u].lazy;
        tr[ur].lazy += tr[u].lazy;
        tr[u].lazy = 0;
    }
    
    void build(int u,int l,int r) {
        tr[u] = {l,r,0};
        if(l == r) {
            tr[u].sum = w[rnk[l]];return;
        }
        build(ul,l,mid);build(ur,mid+1,r);
        push_up(u);
    }
    
    void update(int u,int l,int r,int v) {
        if(l <= tr[u].l && tr[u].r <= r) {
            tr[u].sum += 1ll *  len(u) * v;
            tr[u].lazy += v;
            return;
        }
        if(tr[u].l > r || tr[u].r < l) return;
        push_down(u);
        update(ul,l,r,v);update(ur,l,r,v);
        push_up(u);
    }
    
    ll query_sum(int u,int l,int r) {
        if(l <= tr[u].l && tr[u].r <= r) {return tr[u].sum;}
        if(tr[u].l > r || tr[u].r < l) {return 0;}
        push_down(u);
        return query_sum(ul,l,r) + query_sum(ur,l,r);
    }
} A;

void dfs1(int u) {
    hson[u] = -1;
    sz[u] = 1;
    for (auto& v : edge[u]) {
        if (dep[v])    continue;
        dep[v] = dep[u] + 1;
        fa[v] = u;
        dfs1(v);
        sz[u] += sz[v];
        if (hson[u] == -1 or sz[v] > sz[hson[u]])
            hson[u] = v;
    }
}

void dfs2(int u, int t) {
    top[u] = t;
    dfn[u] = ++tot;
    rnk[tot] = u;
    if (hson[u] != -1)    dfs2(hson[u], t);
    for (auto& v : edge[u])
        if (v != hson[u] and v != fa[u])    dfs2(v, v);
    last[u] = tot;
}

ll query_sum(int x, int y) {
    ll res = 0;
    int fx = top[x], fy = top[y];
    while (fx != fy) {
        if (dep[fx] >= dep[fy]) {
            res += A.query_sum(1, dfn[fx], dfn[x]);
            x = fa[fx];
        }
        else {
            res += A.query_sum(1, dfn[fy], dfn[y]);
            y = fa[fy];
        }
        fx = top[x];
        fy = top[y];
    }
    if (dfn[x] < dfn[y])
        res += A.query_sum(1, dfn[x], dfn[y]);
    else
        res += A.query_sum(1, dfn[y], dfn[x]);
    return res;
}

void solve()
{
    n = read(), m = read();
    fo(i, 1, n)    w[i] = read();
    fo(i, 2, n) {
        int u = read(), v = read();
        edge[u].push_back(v);
        edge[v].push_back(u);
    }
    dep[1] = 1;
    dfs1(1);
    dfs2(1, 1);
    A.build(1, 1, n);
    int op, x, y;
    while (m--) {
        op = read();
        if (op == 1) {
            x = read(), y = read();
            A.update(1, dfn[x], dfn[x], y);
        }
        else if (op == 2) {
            x = read(), y = read();
            A.update(1,dfn[x], last[x], y);
        }
        else {
            x = read();
            ll ans = query_sum(1, x);
            cout<<ans<<endl;
        }
    }
}
void main_init() {}
signed main(){
    AC();clapping();TLE;
    cout<<fixed<<setprecision(12);
    main_init();
//  while(cin>>n,n)
//  while(cin>>n>>m,n,m)
//    int t;cin>>t;while(t -- )
    solve();
//    {solve(); }
    return 0;
}

/*样例区


*/

//------------------------------------------------------------

 

posted @ 2022-09-05 21:37  er007  阅读(19)  评论(0编辑  收藏  举报