A Simple Problem On A Tree(The 2019 ICPC Asia Shanghai Regional Contest)

题目来源

https://ac.nowcoder.com/acm/contest/4370/F

题意分析

  给出一棵树,有四种操作:

  1 x y w 表示将从x到y这条简单路径的上所有点权改成w。
  2 x y w 表示将从x到y这条简单路径的上所有点权加上w。
  3 x y w 表示将从x到y这条简单路径的上所有点权乘上w。
  4 x y  表示求出从x到y这条简单路径上的所有点的点权的立方和。

思路分析

  学过树链剖分的,其实一看就会发现是树链剖分的模板题。主要难题在如何处理点权的修改和乘法。
  首先手动模拟一下立方的加法和乘法是什么情况,进而决定用线段树需要维护的数列。首先需要维护的是立方和的加法,而在立方和加法中,如果对于其中的某个值进行加法运算,那么去掉括号之后会发现他的值和平方和以及一次方和有关,所以再维护一下这两个数值。然后就是长时间的码农时间了。

code

#include <bits/stdc++.h>

#define ll long long
using namespace std;
const int maxn = 1e5 +7;
const int mod = 1e9 + 7;

int n, m, r;
int head[maxn], nxt[maxn << 1], ver[maxn << 1];
int sz[maxn], dep[maxn], fa[maxn], top[maxn], dfn[maxn], id[maxn];
int hcnt = 0;
int tot;
ll sum1[maxn << 2], sum2[maxn << 2], sum3[maxn << 2], a[maxn];
ll lzmul[maxn << 2], lzadd[maxn << 2];

void adde(int u, int v){
    ++tot; ver[tot] = v; nxt[tot] = head[u]; head[u] = tot;
}

void dfs1(int x, int f, int d){
    fa[x] = f; dep[x] = d; sz[x] = 1;
    for (int i=head[x]; i; i=nxt[i]){
        int v = ver[i];
        if (v == f) continue;
        dfs1(v, x, d + 1);
        sz[x] += sz[v];
    }
}

void dfs2(int x, int f){
    dfn[x] = ++hcnt; id[hcnt] = x;
    top[x] = f;
    int pp = 0;
    for (int i=head[x]; i; i=nxt[i]){
        int v = ver[i];
        if (v == fa[x]) continue;
        if (sz[v] > sz[pp]) pp = v;
    }
    if (pp == 0) return;
    dfs2(pp, f);
    for (int i=head[x]; i; i=nxt[i]){
        int v = ver[i];
        if (v == pp || v == fa[x]) continue;
        dfs2(v, v);
    }
}

void pushup(int p){
    sum1[p] = sum1[p << 1] + sum1[p << 1 | 1];
    sum1[p] %= mod;
    sum2[p] = sum2[p << 1] + sum2[p << 1 | 1];
    sum2[p] %= mod;
    sum3[p] = sum3[p << 1] + sum3[p << 1 | 1];
    sum3[p] %= mod;
}

void change(int p, ll x, ll y, int ln){
    if (x != 1){
        ll w = 1ll * x % mod;
        ll w2 = w * w % mod;
        ll w3 = w2 * w % mod;
        sum3[p] *= w3; sum3[p] %= mod;
        sum2[p] *= w2; sum2[p] %= mod;
        sum1[p] *= w; sum1[p] %= mod;

        lzmul[p] *= w; lzmul[p] %= mod;
        lzadd[p] *= w; lzadd[p] %= mod;
    }
    if (y != 0){
        ll w = 1ll * y % mod;
        ll w2 = w * w % mod;
        ll w3 = w2 * w % mod;

        sum3[p] += (1ll* ln * w3 % mod); sum3[p] %= mod;
        sum3[p] += (3ll * w2 * sum1[p]) % mod; sum3[p] %= mod;
        sum3[p] += (3ll * w * sum2[p]) % mod; sum3[p] %= mod;

        sum2[p] += (1ll * ln * w2 % mod); sum2[p] %= mod;
        sum2[p] += (2ll * w * sum1[p]) % mod; sum2[p] %= mod;

        sum1[p] += 1ll * ln * w; sum1[p] %= mod;
        lzadd[p] += w; lzadd[p] %= mod;
    }
}

void pushdown(int p, int lnl, int lnr){
    ll x = lzmul[p], y = lzadd[p];
    change(p << 1, x, y, lnl);
    change(p << 1 | 1, x, y, lnr);
    lzmul[p] = 1; lzadd[p] = 0;
}

void build(int p, int l, int r){
    lzmul[p] = 1; lzadd[p] = 0;
    if (l == r){
        sum1[p] = a[id[l]] % mod;
        sum2[p] = sum1[p] * sum1[p] % mod;
        sum3[p] = sum2[p] * sum1[p] % mod;
        return;
    }
    int mid = l + r >> 1;
    build(p << 1, l, mid);
    build(p << 1 | 1, mid + 1, r);
    pushup(p);
}

ll sum(int p, int l, int r, int L, int R){
//    cout << p << " " << l << " " << r << " " << L << " " << R << endl;
    if (L <= l && r <= R) return sum3[p];
    int mid = l + r >> 1;
    pushdown(p, mid + 1 - l, r - mid);
    ll ans = 0;
    if (mid >= L) ans += sum(p << 1, l, mid, L, R);
    ans %= mod;
    if (mid < R) ans += sum(p << 1 | 1, mid + 1, r, L, R);
    ans %= mod;
    pushup(p);
    return ans;
}

void update(int p, int l, int r, int L, int R, int x, int y){
//    cout << p << " " << l << " " << r << " " << L << " " << R << " " << x << " " << y << endl;
    if (L <= l && r <= R){
        change(p, 1ll*x, 1ll*y, r - l + 1); return;
    }
    int mid = l + r >> 1;
    pushdown(p, mid + 1 - l, r - mid);
    if (L <= mid) update(p << 1, l, mid, L, R, x, y);
    if (mid < R) update(p << 1 | 1, mid + 1, r, L, R, x, y);
    pushup(p);
}

void upd(int u, int v, int x, int y){
    while (top[u] != top[v]){
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        update(1, 1, n, dfn[top[u]], dfn[u], x, y);
        u = fa[top[u]];
    }
    if (dep[u] > dep[v]) swap(u, v);
    update(1, 1, n, dfn[u], dfn[v], x, y);
}

ll ask(int u, int v){
    ll ans = 0;
    while (top[u] != top[v]){
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
//        cout << "!!!!  " << u << " " << dfn[top[u]] << " " << dfn[u] << endl;
        ans += sum(1, 1, n, dfn[top[u]], dfn[u]);
        ans %= mod;
        u = fa[top[u]];
    }
    ans %= mod;
    if (dep[u] < dep[v]) ans += sum(1, 1, n, dfn[u], dfn[v]);
    else if (dep[u] >= dep[v]) ans += sum(1, 1, n, dfn[v], dfn[u]);
    ans %= mod;
    return ans;
}



 int main(){
    int t; scanf("%d", &t);
    int cas = 0;
    while (t --){
        tot = 0; hcnt = 0;
        scanf("%d", &n);
        for (int i=0; i<=n; i++){
            head[i] = 0; sz[i] = 0;id[i] = 0; fa[i] = 0;
        }
        int r = 1;

//        cout << "???" << endl;
        for (int i=1; i<n; i++){
            int u, v;
            scanf("%d%d", &u, &v);
            adde(u, v); adde(v, u);
        }
//        cout << "1111" << endl;
        for (int i=1; i<=n; i++) scanf("%d", &a[i]);
        dfs1(r, 0, 1);
//        cout << "2222" << endl;
        dfs2(r, r);
//        cout << "3333" << endl;
        build(1, 1, n);
//        cout << "4444" << endl;
        int q; scanf("%d", &q);
        printf("Case #%d:\n", ++cas);
        while (q --){
            int op; scanf("%d", &op);
            int u, v; scanf("%d%d", &u, &v);
            int w;
            if (op == 1){
                scanf("%d", &w);
                upd(u, v, 0, w);
            }else if (op == 2){
                scanf("%d", &w);
                upd(u, v, 1, w);
            }else if (op == 3){
                scanf("%d", &w);
                upd(u, v, w, 0);
            }else if (op == 4){
                printf("%d\n", ask(u, v) % mod);
            }
        }
    }
    return 0;
}

 

posted @ 2020-12-08 21:53  Rain_island  阅读(91)  评论(0编辑  收藏  举报
Title