2023湖南省赛 E.ytree (线段树)
大致思路:
1. 将操作1拆分为两个部分x(-1)^d + kd(-1)^d 。对于操作1中的x(-1)^d部分而言。我们可以对式子进行拆分,把x拆出来,我们会发现和v号点距离为奇数的点会减去x,为偶数的点会加上x,所以我们可以在线段树上用一个sum1维护应该减去的值,sum2维护加上的值即可。
2. 随即就是如何维护线段树不同结点之间的sum1和sum2了。我们将整棵树按照dfs序建树,如此一来一颗子树的dfs序是会一段连续的区间,我们在线段树上维护结点的深度最小值mn,当我们将父节点fa上的标记下传到子节点son的时候就可以根据父节点和子节点的最小深度差来下传标记,如果son.mn - fa.mn是奇数,那么son这个结点应该加上的其实是在fa减去的x的值总和,应该减去的其实是在fa加上的x的值的总和,所以就应该这样更新son.sum1 += fa.sum2, son.sum1 += fa.sum2。如果是偶数同理推导
3.再看1中的kd(-1)^d部分如何维护。维护k1和k2两个值,分别表示应该减去的k的总和 和 应该加上的k的总和。还是考虑如何下传标记。当我们将父节点fa上的标记下传到子节点son的时候同样可以根据父节点和子节点的最小深度差来下传标记。如果son.mn - fa.mn = d是奇数,那么son这个结点的sum1应该加上d * fa.k2, sum2应该加上d * fa.k1, son.k1 += fa.k2, son.k2 += fa.k1。偶数同理推导。
在d为奇数的时候,将2和3中的两个式子合并就是son.sum1 = son.sum1 + fa.sum2 + fa.k2 * d, son.sum2 = son.sum2 + fa.sum1 + fa.k1 * d。
4. 操作2就是线段树的单点查询。
5. 操作3我们可以创建一个multiset<array<int, 3>>将每个操作1按照{v的dfs序, x, k}的顺序丢进set。当我们遇到操作3中的v的时候只需要调用set的lowerbound来查找v的dfs序第一个出现的位置,并将这个提取进行修改操作,修改完之后从multiset删除即可。
#include <bits/stdc++.h>
const int N = 2e5 + 10;
const int MOD = 1e9 + 7;
using ll = long long;
typedef std::array<int, 3> PII;
int n, m;
#define ls u << 1
#define rs u << 1 | 1
int w[N], h[N], e[N], ne[N], idx;
int id[N], cnt;
int dep[N], sz[N], dfn[N];
struct Node {
int l, r;
int k1, k2;//k1奇数k的和,k2偶数k的和
int mn;//dfs序最小的点的深度
int sum1, sum2;//sum1奇数和, sum2偶数和
}tr[N << 2];
inline void add(int &x) {
if (x >= MOD) x -= MOD;
x += MOD;
if (x >= MOD) x -= MOD;
}
inline void pushup(int u) {
tr[u].mn = std::min(tr[ls].mn, tr[rs].mn);
}
inline void pushdown(int u) {
int d1 = tr[ls].mn - tr[u].mn;
if (d1 & 1) {
tr[ls].sum1 = (1ll * tr[ls].sum1 + tr[u].sum2 + 1ll * tr[u].k2 * d1) % MOD;
tr[ls].sum2 = (1ll * tr[ls].sum2 + tr[u].sum1 + 1ll * tr[u].k1 * d1) % MOD;
tr[ls].k1 += tr[u].k2;
tr[ls].k2 += tr[u].k1;
add(tr[ls].sum2), add(tr[ls].sum1), add(tr[ls].k1), add(tr[ls].k2);
} else {
tr[ls].sum1 = (1ll * tr[ls].sum1 + tr[u].sum1 + 1ll * tr[u].k1 * d1) % MOD;
tr[ls].sum2 = (1ll * tr[ls].sum2 + tr[u].sum2 + 1ll * tr[u].k2 * d1) % MOD;
tr[ls].k1 += tr[u].k1;
tr[ls].k2 += tr[u].k2;
add(tr[ls].sum2), add(tr[ls].sum1), add(tr[ls].k1), add(tr[ls].k2);
}
int d2 = tr[rs].mn - tr[u].mn;
if (d2 & 1) {
tr[rs].sum1 = (1ll * tr[rs].sum1 + tr[u].sum2 + 1ll * tr[u].k2 * d2) % MOD;
tr[rs].sum2 = (1ll * tr[rs].sum2 + tr[u].sum1 + 1ll * tr[u].k1 * d2) % MOD;
tr[rs].k1 += tr[u].k2;
tr[rs].k2 += tr[u].k1;
add(tr[rs].sum1), add(tr[rs].sum2), add(tr[rs].k1), add(tr[rs].k2);
} else {
tr[rs].sum1 = (1ll * tr[rs].sum1 + tr[u].sum1 + 1ll * tr[u].k1 * d2) % MOD;
tr[rs].sum2 = (1ll * tr[rs].sum2 + tr[u].sum2 + 1ll * tr[u].k2 * d2) % MOD;
tr[rs].k1 += tr[u].k1;
tr[rs].k2 += tr[u].k2;
add(tr[rs].sum1), add(tr[rs].sum2), add(tr[rs].k1), add(tr[rs].k2);
}
tr[u].sum1 = tr[u].sum2 = tr[u].k1 = tr[u].k2 = 0;
}
inline void build(int u, int l, int r){
tr[u] = {l, r};
if(l == r) {
tr[u].mn = dep[dfn[l]];
return ;
}
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
inline void init(){
memset(h, -1, sizeof h);
}
inline void add(int a, int b){
e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}
int rr[N];
inline void dfs(int u, int father, int depth){
dep[u] = depth, id[u] = ++ cnt, sz[u] = 1;
dfn[cnt] = u;
for (int i = h[u]; ~i; i = ne[i]){
int j = e[i];
if (j == father) continue;
dfs(j, u, depth + 1);
sz[u] += sz[j];
}
rr[u] = cnt;
}
inline void modify(int u, int L, int R, int x, int k, int depth) {
if (tr[u].l >= L && tr[u].r <= R) {
int d = tr[u].mn - depth;
if (d & 1) {
tr[u].sum1 = (ll(tr[u].sum1) + x + 1ll * d * k) % MOD;
tr[u].k1 += k;
add(tr[u].sum1);
add(tr[u].k1);
} else {
tr[u].sum2 = (ll(tr[u].sum2) + x + 1ll * d * k) % MOD;
tr[u].k2 += k;
add(tr[u].sum2);
add(tr[u].k2);
}
return ;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (L <= mid) modify(ls, L, R, x, k, depth);
if (R > mid) modify(rs, L, R, x, k, depth);
}
inline int query(int u, int x) {
if (tr[u].l == tr[u].r) return (tr[u].sum2 - tr[u].sum1 + MOD) % MOD;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) return query(ls, x);
return query(rs, x);
}
inline void solve() {
memset(h, -1, sizeof h);
std::cin >> n >> m;
for (int i = 2; i <= n; i ++) {
int x;
std::cin >> x;
add(x, i);
}
dfs(1, -1, 1);
std::multiset<PII> st;
build(1, 1, n);
constexpr int INF = 0x3f3f3f3f;
auto get = [&](int sb) {
auto it = st.lower_bound({sb, -INF, -INF});
if (it == st.end()) {
PII sn = {10000000, 0, 0};
return sn;
}
return *it;
};
while (m --) {
int op;
std::cin >> op;
if (op == 1) {
int x, v, k;
std::cin >> v >> x >> k;
modify(1, id[v], id[v] + sz[v] - 1, x, k, dep[v]);
st.insert({id[v], x, k});
} else if (op == 2){
int v;
std::cin >> v;
std::cout << query(1, id[v]) << '\n';
} else {
int z;
std::cin >> z;
for (int t = get(id[z])[0]; t <= rr[z]; t = get(t)[0]) {
auto [dfnn, x, k] = get(t);
modify(1, dfnn, dfnn + sz[dfn[dfnn]] - 1, -x, -k, dep[dfn[dfnn]]);
st.erase(st.find({dfnn, x, k}));
}
}
}
}
signed main(void) {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cout.tie(nullptr);
int _ = 1;
//std::cin >> _;
while (_ --) solve();
return 0;
}
说一个可能比较难发现会错的写法。线段树上搞一个标记永久化。
LYL告诉我对于操作三标记永久化就行了,操作3回退的时候定位到区间将标记永久化移除,当时也是脑子抽了居然觉得很对。然后我就多开了一个sum3和sum4,k3和k4来标记永久化,还有tag来打标记。但是线段树上会出现这么一个问题,对于区间[1, 9]和区间[1, 10]其实都是在子区间[x1, y1], [x2, y2]这样的区间上打了标记,但是要回退的时候根本无法区分是[1, 10]在[x1, y1]上打的标记还是[1, 9]在[x1, y1]上打的标记。
#include <bits/stdc++.h>
const int N = 2e5 + 10;
const int MOD = 1e9 + 7;
using ll = long long;
typedef std::array<int, 3> PII;
int n, m;
#define ls u << 1
#define rs u << 1 | 1
int w[N], h[N], e[N], ne[N], idx;
int id[N], cnt;
int dep[N], sz[N], dfn[N];
struct node {
int l, r, tag;
std::vector<PII> tmp;
}trr[N << 2];
struct Node {
int l, r;
int k1, k2;//k1奇数k的和,k2偶数k的和
int mn;//dfs序最小的点的深度
int sum1, sum2, sum3, sum4;//sum1奇数和, sum2偶数和
int k3, k4;
int tag;
}tr[N << 2];
inline void add(int &x) {
if (x >= MOD) x -= MOD;
x += MOD;
if (x >= MOD) x -= MOD;
}
inline void pushup(int u) {
tr[u].mn = std::min(tr[ls].mn, tr[rs].mn);
}
inline void pushdown(int u) {
int d1 = tr[ls].mn - tr[u].mn;
if (d1 & 1) {
tr[ls].sum1 = (1ll * tr[ls].sum1 + tr[u].sum2 + 1ll * tr[u].k2 * d1) % MOD;
tr[ls].sum2 = (1ll * tr[ls].sum2 + tr[u].sum1 + 1ll * tr[u].k1 * d1) % MOD;
tr[ls].k1 += tr[u].k2;
tr[ls].k2 += tr[u].k1;
add(tr[ls].sum2), add(tr[ls].sum1), add(tr[ls].k1), add(tr[ls].k2);
} else {
tr[ls].sum1 = (1ll * tr[ls].sum1 + tr[u].sum1 + 1ll * tr[u].k1 * d1) % MOD;
tr[ls].sum2 = (1ll * tr[ls].sum2 + tr[u].sum2 + 1ll * tr[u].k2 * d1) % MOD;
tr[ls].k1 += tr[u].k1;
tr[ls].k2 += tr[u].k2;
add(tr[ls].sum2), add(tr[ls].sum1), add(tr[ls].k1), add(tr[ls].k2);
}
int d2 = tr[rs].mn - tr[u].mn;
if (d2 & 1) {
tr[rs].sum1 = (1ll * tr[rs].sum1 + tr[u].sum2 + 1ll * tr[u].k2 * d2) % MOD;
tr[rs].sum2 = (1ll * tr[rs].sum2 + tr[u].sum1 + 1ll * tr[u].k1 * d2) % MOD;
tr[rs].k1 += tr[u].k2;
tr[rs].k2 += tr[u].k1;
add(tr[rs].sum1), add(tr[rs].sum2), add(tr[rs].k1), add(tr[rs].k2);
} else {
tr[rs].sum1 = (1ll * tr[rs].sum1 + tr[u].sum1 + 1ll * tr[u].k1 * d2) % MOD;
tr[rs].sum2 = (1ll * tr[rs].sum2 + tr[u].sum2 + 1ll * tr[u].k2 * d2) % MOD;
tr[rs].k1 += tr[u].k1;
tr[rs].k2 += tr[u].k2;
add(tr[rs].sum1), add(tr[rs].sum2), add(tr[rs].k1), add(tr[rs].k2);
}
tr[u].sum1 = tr[u].sum2 = tr[u].k1 = tr[u].k2 = 0;
}
inline void build(int u, int l, int r){
trr[u] = {l, r};
tr[u] = {l, r};
if(l == r) {
tr[u].mn = dep[dfn[l]];
return ;
}
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
inline void init(){
memset(h, -1, sizeof h);
}
inline void add(int a, int b){
e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}
inline void dfs(int u, int father, int depth){
dep[u] = depth, id[u] = ++ cnt, sz[u] = 1;
dfn[cnt] = u;
for (int i = h[u]; ~i; i = ne[i]){
int j = e[i];
if (j == father) continue;
dfs(j, u, depth + 1);
sz[u] += sz[j];
}
}
inline void modify(int u, int L, int R, int x, int k, int depth) {
if (tr[u].l >= L && tr[u].r <= R) {
int d = tr[u].mn - depth;
if (d & 1) {
tr[u].sum1 = (ll(tr[u].sum1) + x + 1ll * d * k) % MOD;
tr[u].k1 += k;
add(tr[u].sum1);
add(tr[u].k1);
} else {
tr[u].sum2 = (ll(tr[u].sum2) + x + 1ll * d * k) % MOD;
tr[u].k2 += k;
add(tr[u].sum2);
add(tr[u].k2);
}
return ;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (L <= mid) modify(ls, L, R, x, k, depth);
if (R > mid) modify(rs, L, R, x, k, depth);
}
inline void pushdown1(int u) {
if (trr[u].tag) {
trr[ls].tag = trr[rs].tag = 1;
trr[ls].tmp.clear(), trr[rs].tmp.clear();
trr[u].tag = 0;
}
}
inline std::vector<PII> query(int u, int L, int R) {
if (trr[u].l >= L && trr[u].r <= R) {
trr[u].tag = 1;
std::vector<PII> now = trr[u].tmp;
trr[u].tmp.clear();
return now;
}
pushdown1(u);
std::vector<PII> cur, cur1, cur2;
int mid = trr[u].l + trr[u].r >> 1;
if (L <= mid) cur1 = query(ls, L, R);
if (R > mid) cur2 = query(rs, L, R);
cur.insert(cur.end(), cur1.begin(), cur1.end());
cur.insert(cur.end(), cur2.begin(), cur2.end());
return cur;
}
inline void modify(int u, int x, int v, int xx, int k) {
trr[u].tmp.push_back({v, xx, k});
if (trr[u].l == trr[u].r) return ;
pushdown1(u);
int mid = trr[u].l + trr[u].r >> 1;
if (x <= mid) modify(ls, x, v, xx, k);
else modify(rs, x, v, xx, k);
}
inline int query(int u, int x) {
if (tr[u].l == tr[u].r) return (tr[u].sum2 % MOD - tr[u].sum1 % MOD + MOD) % MOD;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) return query(ls, x);
return query(rs, x);
}
inline void solve() {
memset(h, -1, sizeof h);
std::cin >> n >> m;
for (int i = 2; i <= n; i ++) {
int x;
std::cin >> x;
add(x, i);
}
dfs(1, -1, 1);
build(1, 1, n);
while (m --) {
int op;
std::cin >> op;
if (op == 1) {
int x, v, k;
std::cin >> v >> x >> k;
modify(1, id[v], id[v] + sz[v] - 1, x, k, dep[v]);
modify(1, id[v], v, x, k);
} else if (op == 2){
int v;
std::cin >> v;
std::cout << query(1, id[v]) << '\n';
} else {
int z;
std::cin >> z;
std::vector<PII> t = query(1, id[z], id[z] + sz[z] - 1);
for (auto &[v, x, k]: t)
modify(1, id[v], id[v] + sz[v] - 1, -x, -k, dep[v]);
}
}
}
int main(void) {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cout.tie(nullptr);
int _ = 1;
//std::cin >> _;
while (_ --) solve();
return 0;
}