2023湖南省赛 E.ytree (线段树)

传送门

大致思路:

1. 将操作1拆分为两个部分x(-1)^d + kd(-1)^d 。对于操作1中的x(-1)^d部分而言。我们可以对式子进行拆分,把x拆出来,我们会发现和v号点距离为奇数的点会减去x,为偶数的点会加上x,所以我们可以在线段树上用一个sum1维护应该减去的值,sum2维护加上的值即可。

2. 随即就是如何维护线段树不同结点之间的sum1和sum2了。我们将整棵树按照dfs序建树,如此一来一颗子树的dfs序是会一段连续的区间,我们在线段树上维护结点的深度最小值mn,当我们将父节点fa上的标记下传到子节点son的时候就可以根据父节点和子节点的最小深度差来下传标记,如果son.mn - fa.mn是奇数,那么son这个结点应该加上的其实是在fa减去的x的值总和,应该减去的其实是在fa加上的x的值的总和,所以就应该这样更新son.sum1 += fa.sum2, son.sum1 += fa.sum2。如果是偶数同理推导

3.再看1中的kd(-1)^d部分如何维护。维护k1和k2两个值,分别表示应该减去的k的总和 和 应该加上的k的总和。还是考虑如何下传标记。当我们将父节点fa上的标记下传到子节点son的时候同样可以根据父节点和子节点的最小深度差来下传标记。如果son.mn - fa.mn = d是奇数,那么son这个结点的sum1应该加上d * fa.k2, sum2应该加上d * fa.k1, son.k1 += fa.k2, son.k2 += fa.k1。偶数同理推导。

  在d为奇数的时候,将2和3中的两个式子合并就是son.sum1 = son.sum1 + fa.sum2 + fa.k2 * d, son.sum2 = son.sum2 + fa.sum1 + fa.k1 * d。

4. 操作2就是线段树的单点查询。

5. 操作3我们可以创建一个multiset<array<int, 3>>将每个操作1按照{v的dfs序, x, k}的顺序丢进set。当我们遇到操作3中的v的时候只需要调用set的lowerbound来查找v的dfs序第一个出现的位置,并将这个提取进行修改操作,修改完之后从multiset删除即可。

#include <bits/stdc++.h>

const int N = 2e5 + 10;
const int MOD = 1e9 + 7;
using ll = long long;
typedef std::array<int, 3> PII;

int n, m;
#define ls u << 1
#define rs u << 1 | 1

int w[N], h[N], e[N], ne[N], idx;
int id[N], cnt;
int dep[N], sz[N], dfn[N];

struct Node {
	int l, r;                                                                                                               
	int k1, k2;//k1奇数k的和,k2偶数k的和
	int mn;//dfs序最小的点的深度
	int sum1, sum2;//sum1奇数和, sum2偶数和
}tr[N << 2];

inline void add(int &x) {
	if (x >= MOD) x -= MOD;
	x += MOD;
	if (x >= MOD) x -= MOD;
}

inline void pushup(int u) {
	tr[u].mn = std::min(tr[ls].mn, tr[rs].mn);
}

inline void pushdown(int u) {
	int d1 = tr[ls].mn - tr[u].mn;
	if (d1 & 1) {
		tr[ls].sum1 = (1ll * tr[ls].sum1 + tr[u].sum2 + 1ll * tr[u].k2 * d1) % MOD;
		tr[ls].sum2 = (1ll * tr[ls].sum2 + tr[u].sum1 + 1ll * tr[u].k1 * d1) % MOD;
		tr[ls].k1 += tr[u].k2;
		tr[ls].k2 += tr[u].k1;
		
		add(tr[ls].sum2), add(tr[ls].sum1), add(tr[ls].k1), add(tr[ls].k2);
	} else {
		tr[ls].sum1 = (1ll * tr[ls].sum1 + tr[u].sum1 + 1ll * tr[u].k1 * d1) % MOD;
		tr[ls].sum2 = (1ll * tr[ls].sum2 + tr[u].sum2 + 1ll * tr[u].k2 * d1) % MOD;
		tr[ls].k1 += tr[u].k1;
		tr[ls].k2 += tr[u].k2;
		add(tr[ls].sum2), add(tr[ls].sum1), add(tr[ls].k1), add(tr[ls].k2);
	}
	
	int d2 = tr[rs].mn - tr[u].mn;
	
	if (d2 & 1) {
		tr[rs].sum1 = (1ll * tr[rs].sum1 + tr[u].sum2 + 1ll * tr[u].k2 * d2) % MOD;
		tr[rs].sum2 = (1ll * tr[rs].sum2 + tr[u].sum1 + 1ll * tr[u].k1 * d2) % MOD;
		tr[rs].k1 += tr[u].k2;
		tr[rs].k2 += tr[u].k1;
		add(tr[rs].sum1), add(tr[rs].sum2), add(tr[rs].k1), add(tr[rs].k2);
	} else {
		tr[rs].sum1 = (1ll * tr[rs].sum1 + tr[u].sum1 + 1ll * tr[u].k1 * d2) % MOD;
		tr[rs].sum2 = (1ll * tr[rs].sum2 + tr[u].sum2 + 1ll * tr[u].k2 * d2) % MOD;
		tr[rs].k1 += tr[u].k1;
		tr[rs].k2 += tr[u].k2;
		add(tr[rs].sum1), add(tr[rs].sum2), add(tr[rs].k1), add(tr[rs].k2);
	}
	
	tr[u].sum1 = tr[u].sum2 = tr[u].k1 = tr[u].k2 = 0;
}

inline void build(int u, int l, int r){
	tr[u] = {l, r};
	if(l == r)	{
		tr[u].mn = dep[dfn[l]];
		return ;
	}
	int mid = l + r >> 1;
	build(u << 1, l, mid);
	build(u << 1 | 1, mid + 1, r);
	
	pushup(u);
}

inline void init(){
	memset(h, -1, sizeof h);
}

inline void add(int a, int b){
    e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}
int rr[N];
inline void dfs(int u, int father, int depth){
    dep[u] = depth, id[u] = ++ cnt, sz[u] = 1;
    dfn[cnt] = u;
    for (int i = h[u]; ~i; i = ne[i]){
        int j = e[i];
        if (j == father) continue;
        dfs(j, u, depth + 1);
        sz[u] += sz[j];
    }
    rr[u] = cnt;
}

inline void modify(int u, int L, int R, int x, int k, int depth) {
	if (tr[u].l >= L && tr[u].r <= R) {
		int d = tr[u].mn - depth;
		if (d & 1) {
			tr[u].sum1 = (ll(tr[u].sum1) + x + 1ll * d * k) % MOD;
			tr[u].k1 += k;
			add(tr[u].sum1);
			add(tr[u].k1);
		} else {
			tr[u].sum2 = (ll(tr[u].sum2) + x + 1ll * d * k) % MOD;
			tr[u].k2 += k;
			add(tr[u].sum2);
			add(tr[u].k2);
		}
 		return ;
	}
	
	pushdown(u);
	
	int mid = tr[u].l + tr[u].r >> 1;
	if (L <= mid) modify(ls, L, R, x, k, depth);
	if (R > mid) modify(rs, L, R, x, k, depth);
}

inline int query(int u, int x) {
	if (tr[u].l == tr[u].r) return (tr[u].sum2 - tr[u].sum1 + MOD) % MOD;
	
	pushdown(u);
	
	int mid = tr[u].l + tr[u].r >> 1;
	
	if (x <= mid) return query(ls, x);
	return query(rs, x);
}

inline void solve() {
	memset(h, -1, sizeof h);
	std::cin >> n >> m;
	
	for (int i = 2; i <= n; i ++) {
		int x;
		std::cin >> x;
		add(x, i);
	}
	
	dfs(1, -1, 1);
	
	std::multiset<PII> st;

	build(1, 1, n);
	constexpr int INF = 0x3f3f3f3f;
	auto get = [&](int sb) {
		auto it = st.lower_bound({sb, -INF, -INF});
		if (it == st.end()) {
			PII sn = {10000000, 0, 0};
			return sn;
		}
		return *it;
	};
	
	while (m --) {
		int op;
		std::cin >> op;
		if (op == 1) {
			int x, v, k;
			std::cin >> v >> x >> k;
			modify(1, id[v], id[v] + sz[v] - 1, x, k, dep[v]);
			st.insert({id[v], x, k});
		} else if (op == 2){
			int v;
			std::cin >> v;
			std::cout << query(1, id[v]) << '\n';
		} else {
			int z;
			std::cin >> z;
			for (int t = get(id[z])[0]; t <= rr[z]; t = get(t)[0]) {
				auto [dfnn, x, k] = get(t);
				modify(1, dfnn, dfnn + sz[dfn[dfnn]] - 1, -x, -k, dep[dfn[dfnn]]);
				st.erase(st.find({dfnn, x, k}));
			}
		}
	}
}

signed main(void) {
	std::ios::sync_with_stdio(false);
	std::cin.tie(nullptr);
	std::cout.tie(nullptr);
	
	int _ = 1;
	
	//std::cin >> _;
	while (_ --) solve();
	
	return 0;
}

说一个可能比较难发现会错的写法。线段树上搞一个标记永久化。

LYL告诉我对于操作三标记永久化就行了,操作3回退的时候定位到区间将标记永久化移除,当时也是脑子抽了居然觉得很对。然后我就多开了一个sum3和sum4,k3和k4来标记永久化,还有tag来打标记。但是线段树上会出现这么一个问题,对于区间[1, 9]和区间[1, 10]其实都是在子区间[x1, y1], [x2, y2]这样的区间上打了标记,但是要回退的时候根本无法区分是[1, 10]在[x1, y1]上打的标记还是[1, 9]在[x1, y1]上打的标记。

#include <bits/stdc++.h>

const int N = 2e5 + 10;
const int MOD = 1e9 + 7;
using ll = long long;
typedef std::array<int, 3> PII;

int n, m;
#define ls u << 1
#define rs u << 1 | 1

int w[N], h[N], e[N], ne[N], idx;
int id[N], cnt;
int dep[N], sz[N], dfn[N];

struct node {
	int l, r, tag;
	std::vector<PII> tmp; 
}trr[N << 2];

struct Node {
	int l, r;                                                                                                
	int k1, k2;//k1奇数k的和,k2偶数k的和
	int mn;//dfs序最小的点的深度
	int sum1, sum2, sum3, sum4;//sum1奇数和, sum2偶数和
	int k3, k4;
	int tag;
}tr[N << 2];

inline void add(int &x) {
    if (x >= MOD) x -= MOD;
	x += MOD;
	if (x >= MOD) x -= MOD;
}

inline void pushup(int u) {
	tr[u].mn = std::min(tr[ls].mn, tr[rs].mn);
}

inline void pushdown(int u) {
	int d1 = tr[ls].mn - tr[u].mn;
	if (d1 & 1) {
		tr[ls].sum1 = (1ll * tr[ls].sum1 + tr[u].sum2 + 1ll * tr[u].k2 * d1) % MOD;
		tr[ls].sum2 = (1ll * tr[ls].sum2 + tr[u].sum1 + 1ll * tr[u].k1 * d1) % MOD;
		tr[ls].k1 += tr[u].k2;
		tr[ls].k2 += tr[u].k1;
		add(tr[ls].sum2), add(tr[ls].sum1), add(tr[ls].k1), add(tr[ls].k2);
	} else {
		tr[ls].sum1 = (1ll * tr[ls].sum1 + tr[u].sum1 + 1ll * tr[u].k1 * d1) % MOD;
		tr[ls].sum2 = (1ll * tr[ls].sum2 + tr[u].sum2 + 1ll * tr[u].k2 * d1) % MOD;
		tr[ls].k1 += tr[u].k1;
		tr[ls].k2 += tr[u].k2;
		add(tr[ls].sum2), add(tr[ls].sum1), add(tr[ls].k1), add(tr[ls].k2);
	}
	
	int d2 = tr[rs].mn - tr[u].mn;
	
	if (d2 & 1) {
		tr[rs].sum1 = (1ll * tr[rs].sum1 + tr[u].sum2 + 1ll * tr[u].k2 * d2) % MOD;
		tr[rs].sum2 = (1ll * tr[rs].sum2 + tr[u].sum1 + 1ll * tr[u].k1 * d2) % MOD;
		tr[rs].k1 += tr[u].k2;
		tr[rs].k2 += tr[u].k1;
		add(tr[rs].sum1), add(tr[rs].sum2), add(tr[rs].k1), add(tr[rs].k2);
	} else {
		tr[rs].sum1 = (1ll * tr[rs].sum1 + tr[u].sum1 + 1ll * tr[u].k1 * d2) % MOD;
		tr[rs].sum2 = (1ll * tr[rs].sum2 + tr[u].sum2 + 1ll * tr[u].k2 * d2) % MOD;
		tr[rs].k1 += tr[u].k1;
		tr[rs].k2 += tr[u].k2;
		add(tr[rs].sum1), add(tr[rs].sum2), add(tr[rs].k1), add(tr[rs].k2);
	}
	
	tr[u].sum1 = tr[u].sum2 = tr[u].k1 = tr[u].k2 = 0;
}

inline void build(int u, int l, int r){
	trr[u] = {l, r};
	tr[u] = {l, r};
	if(l == r)	{
		tr[u].mn = dep[dfn[l]];
		return ;
	}
	int mid = l + r >> 1;
	build(u << 1, l, mid);
	build(u << 1 | 1, mid + 1, r);
	
	pushup(u);
}

inline void init(){
	memset(h, -1, sizeof h);
}

inline void add(int a, int b){
    e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}

inline void dfs(int u, int father, int depth){
    dep[u] = depth, id[u] = ++ cnt, sz[u] = 1;
    dfn[cnt] = u;
    for (int i = h[u]; ~i; i = ne[i]){
        int j = e[i];
        if (j == father) continue;
        dfs(j, u, depth + 1);
        sz[u] += sz[j];
    }
}

inline void modify(int u, int L, int R, int x, int k, int depth) {
	if (tr[u].l >= L && tr[u].r <= R) {
		int d = tr[u].mn - depth;
		if (d & 1) {
			tr[u].sum1 = (ll(tr[u].sum1) + x + 1ll * d * k) % MOD;
			tr[u].k1 += k;
			add(tr[u].sum1);
			add(tr[u].k1);
		} else {
			tr[u].sum2 = (ll(tr[u].sum2) + x + 1ll * d * k) % MOD;
			tr[u].k2 += k;
			add(tr[u].sum2);
			add(tr[u].k2);
		}
 		return ;
	}
	
	pushdown(u);
	
	int mid = tr[u].l + tr[u].r >> 1;
	if (L <= mid) modify(ls, L, R, x, k, depth);
	if (R > mid) modify(rs, L, R, x, k, depth);
}

inline void pushdown1(int u) {
	if (trr[u].tag) {
		trr[ls].tag = trr[rs].tag = 1;
		trr[ls].tmp.clear(), trr[rs].tmp.clear();
		trr[u].tag = 0;
	}
}

inline std::vector<PII> query(int u, int L, int R) {
	if (trr[u].l >= L && trr[u].r <= R) {
		trr[u].tag = 1;
		std::vector<PII> now = trr[u].tmp;
		trr[u].tmp.clear();
 		return now;
	}
	
	pushdown1(u);
	
	std::vector<PII> cur, cur1, cur2;
	int mid = trr[u].l + trr[u].r >> 1;
	if (L <= mid) cur1 = query(ls, L, R);
	if (R > mid) cur2 = query(rs, L, R);
	cur.insert(cur.end(), cur1.begin(), cur1.end());
	cur.insert(cur.end(), cur2.begin(), cur2.end());
	return cur;
}

inline void modify(int u, int x, int v, int xx, int k) {
	trr[u].tmp.push_back({v, xx, k});
	if (trr[u].l == trr[u].r) return ;
	
	pushdown1(u);
	
	int mid = trr[u].l + trr[u].r >> 1;
	if (x <= mid) modify(ls, x, v, xx, k);
	else modify(rs, x, v, xx, k);
}

inline int query(int u, int x) {
	if (tr[u].l == tr[u].r) return (tr[u].sum2 % MOD - tr[u].sum1 % MOD + MOD) % MOD;
	
	pushdown(u);
	
	int mid = tr[u].l + tr[u].r >> 1;
	
	if (x <= mid) return query(ls, x);
	return query(rs, x);
}

inline void solve() {
	memset(h, -1, sizeof h);
	std::cin >> n >> m;
	
	for (int i = 2; i <= n; i ++) {
		int x;
		std::cin >> x;
		add(x, i);
	}
	
	dfs(1, -1, 1);

	build(1, 1, n);
	
	while (m --) {
		int op;
		std::cin >> op;
		if (op == 1) {
			int x, v, k;
			std::cin >> v >> x >> k;
			modify(1, id[v], id[v] + sz[v] - 1, x, k, dep[v]);
			modify(1, id[v], v, x, k);
		} else if (op == 2){
			int v;
			std::cin >> v;
			std::cout << query(1, id[v]) << '\n';
		} else {
			int z;
			std::cin >> z;
			std::vector<PII> t = query(1, id[z], id[z] + sz[z] - 1);
			for (auto &[v, x, k]: t)
				modify(1, id[v], id[v] + sz[v] - 1, -x, -k, dep[v]);
			
		}
	}
}

int main(void) {
	std::ios::sync_with_stdio(false);
	std::cin.tie(nullptr);
	std::cout.tie(nullptr);
	
	int _ = 1;
	
	//std::cin >> _;
	while (_ --) solve();
	
	return 0;
}
posted @ 2023-09-24 15:07  春始于雪之下  阅读(76)  评论(0编辑  收藏  举报