A Simple Problem On A Tree(The 2019 ICPC Asia Shanghai Regional Contest)
题目来源
https://ac.nowcoder.com/acm/contest/4370/F
题意分析
给出一棵树,有四种操作:
1 x y w 表示将从x到y这条简单路径的上所有点权改成w。
2 x y w 表示将从x到y这条简单路径的上所有点权加上w。
3 x y w 表示将从x到y这条简单路径的上所有点权乘上w。
4 x y 表示求出从x到y这条简单路径上的所有点的点权的立方和。
思路分析
学过树链剖分的,其实一看就会发现是树链剖分的模板题。主要难题在如何处理点权的修改和乘法。
首先手动模拟一下立方的加法和乘法是什么情况,进而决定用线段树需要维护的数列。首先需要维护的是立方和的加法,而在立方和加法中,如果对于其中的某个值进行加法运算,那么去掉括号之后会发现他的值和平方和以及一次方和有关,所以再维护一下这两个数值。然后就是长时间的码农时间了。
code
#include <bits/stdc++.h> #define ll long long using namespace std; const int maxn = 1e5 +7; const int mod = 1e9 + 7; int n, m, r; int head[maxn], nxt[maxn << 1], ver[maxn << 1]; int sz[maxn], dep[maxn], fa[maxn], top[maxn], dfn[maxn], id[maxn]; int hcnt = 0; int tot; ll sum1[maxn << 2], sum2[maxn << 2], sum3[maxn << 2], a[maxn]; ll lzmul[maxn << 2], lzadd[maxn << 2]; void adde(int u, int v){ ++tot; ver[tot] = v; nxt[tot] = head[u]; head[u] = tot; } void dfs1(int x, int f, int d){ fa[x] = f; dep[x] = d; sz[x] = 1; for (int i=head[x]; i; i=nxt[i]){ int v = ver[i]; if (v == f) continue; dfs1(v, x, d + 1); sz[x] += sz[v]; } } void dfs2(int x, int f){ dfn[x] = ++hcnt; id[hcnt] = x; top[x] = f; int pp = 0; for (int i=head[x]; i; i=nxt[i]){ int v = ver[i]; if (v == fa[x]) continue; if (sz[v] > sz[pp]) pp = v; } if (pp == 0) return; dfs2(pp, f); for (int i=head[x]; i; i=nxt[i]){ int v = ver[i]; if (v == pp || v == fa[x]) continue; dfs2(v, v); } } void pushup(int p){ sum1[p] = sum1[p << 1] + sum1[p << 1 | 1]; sum1[p] %= mod; sum2[p] = sum2[p << 1] + sum2[p << 1 | 1]; sum2[p] %= mod; sum3[p] = sum3[p << 1] + sum3[p << 1 | 1]; sum3[p] %= mod; } void change(int p, ll x, ll y, int ln){ if (x != 1){ ll w = 1ll * x % mod; ll w2 = w * w % mod; ll w3 = w2 * w % mod; sum3[p] *= w3; sum3[p] %= mod; sum2[p] *= w2; sum2[p] %= mod; sum1[p] *= w; sum1[p] %= mod; lzmul[p] *= w; lzmul[p] %= mod; lzadd[p] *= w; lzadd[p] %= mod; } if (y != 0){ ll w = 1ll * y % mod; ll w2 = w * w % mod; ll w3 = w2 * w % mod; sum3[p] += (1ll* ln * w3 % mod); sum3[p] %= mod; sum3[p] += (3ll * w2 * sum1[p]) % mod; sum3[p] %= mod; sum3[p] += (3ll * w * sum2[p]) % mod; sum3[p] %= mod; sum2[p] += (1ll * ln * w2 % mod); sum2[p] %= mod; sum2[p] += (2ll * w * sum1[p]) % mod; sum2[p] %= mod; sum1[p] += 1ll * ln * w; sum1[p] %= mod; lzadd[p] += w; lzadd[p] %= mod; } } void pushdown(int p, int lnl, int lnr){ ll x = lzmul[p], y = lzadd[p]; change(p << 1, x, y, lnl); change(p << 1 | 1, x, y, lnr); lzmul[p] = 1; lzadd[p] = 0; } void build(int p, int l, int r){ lzmul[p] = 1; lzadd[p] = 0; if (l == r){ sum1[p] = a[id[l]] % mod; sum2[p] = sum1[p] * sum1[p] % mod; sum3[p] = sum2[p] * sum1[p] % mod; return; } int mid = l + r >> 1; build(p << 1, l, mid); build(p << 1 | 1, mid + 1, r); pushup(p); } ll sum(int p, int l, int r, int L, int R){ // cout << p << " " << l << " " << r << " " << L << " " << R << endl; if (L <= l && r <= R) return sum3[p]; int mid = l + r >> 1; pushdown(p, mid + 1 - l, r - mid); ll ans = 0; if (mid >= L) ans += sum(p << 1, l, mid, L, R); ans %= mod; if (mid < R) ans += sum(p << 1 | 1, mid + 1, r, L, R); ans %= mod; pushup(p); return ans; } void update(int p, int l, int r, int L, int R, int x, int y){ // cout << p << " " << l << " " << r << " " << L << " " << R << " " << x << " " << y << endl; if (L <= l && r <= R){ change(p, 1ll*x, 1ll*y, r - l + 1); return; } int mid = l + r >> 1; pushdown(p, mid + 1 - l, r - mid); if (L <= mid) update(p << 1, l, mid, L, R, x, y); if (mid < R) update(p << 1 | 1, mid + 1, r, L, R, x, y); pushup(p); } void upd(int u, int v, int x, int y){ while (top[u] != top[v]){ if (dep[top[u]] < dep[top[v]]) swap(u, v); update(1, 1, n, dfn[top[u]], dfn[u], x, y); u = fa[top[u]]; } if (dep[u] > dep[v]) swap(u, v); update(1, 1, n, dfn[u], dfn[v], x, y); } ll ask(int u, int v){ ll ans = 0; while (top[u] != top[v]){ if (dep[top[u]] < dep[top[v]]) swap(u, v); // cout << "!!!! " << u << " " << dfn[top[u]] << " " << dfn[u] << endl; ans += sum(1, 1, n, dfn[top[u]], dfn[u]); ans %= mod; u = fa[top[u]]; } ans %= mod; if (dep[u] < dep[v]) ans += sum(1, 1, n, dfn[u], dfn[v]); else if (dep[u] >= dep[v]) ans += sum(1, 1, n, dfn[v], dfn[u]); ans %= mod; return ans; } int main(){ int t; scanf("%d", &t); int cas = 0; while (t --){ tot = 0; hcnt = 0; scanf("%d", &n); for (int i=0; i<=n; i++){ head[i] = 0; sz[i] = 0;id[i] = 0; fa[i] = 0; } int r = 1; // cout << "???" << endl; for (int i=1; i<n; i++){ int u, v; scanf("%d%d", &u, &v); adde(u, v); adde(v, u); } // cout << "1111" << endl; for (int i=1; i<=n; i++) scanf("%d", &a[i]); dfs1(r, 0, 1); // cout << "2222" << endl; dfs2(r, r); // cout << "3333" << endl; build(1, 1, n); // cout << "4444" << endl; int q; scanf("%d", &q); printf("Case #%d:\n", ++cas); while (q --){ int op; scanf("%d", &op); int u, v; scanf("%d%d", &u, &v); int w; if (op == 1){ scanf("%d", &w); upd(u, v, 0, w); }else if (op == 2){ scanf("%d", &w); upd(u, v, 1, w); }else if (op == 3){ scanf("%d", &w); upd(u, v, w, 0); }else if (op == 4){ printf("%d\n", ask(u, v) % mod); } } } return 0; }