牛客算法周周练15D - 树上求和(DFS序 + 线段树维护平方和)

在这里插入图片描述在这里插入图片描述

题目大意:

输入n q 表示树有 n 个节点和 q 次操作,然后输入n个数,表示 n 个节点的权值,之后输入 n - 1 条边,再输入 q 次操作,1 x y表示给 x 所在的子树权值 + y,2 x 则输出 x所在子树的所有节点的平方和

解题思路:

树上问题的模板题,遇到树,而且是在子树上进行操作,先跑一遍dfs序将这棵树线性化,因为dfs序跑完以后,子数部分一定是连续的,而且可以求出该节点所在子树的起始时间戳和终止时间戳,然后子树上的问题就转化为对区间进行操作的问题了,维护一个平方和,愉快的使用线段树即可。

关于dfs序:
这道题需要先跑dfs序线性化,在dfs序中有这样几个值,in数组,in[x] 表示x节点进入序列的时间戳,out[x] 表示x从序列中出去的时间戳,in[x] -> out[x]即可表示 x 的子树,还有一个数组dfsn 表示遍历顺序,也就是树线性化后的数组。

关于线段树维树平方和问题:
该题维护的是平方和,也就是对于每个节点,是要变为 (a+y)2
而不是a2+y2,推一个公式:
平方和(a+b)2 = a2 + 2ab + b2,然而对于这道题来说,就可以转换为
Σa2 + 2 ·Σa · len · y + Σy2
,len表示区间长度。单点的原始值则是(Σa + len * b),lazy中存y的值,sum1 表示区间原始值, sum2 表示区间平方和, 推的时候一定要先推平方和,因为平方和是由原始值转化过来的,如果先推原始值会出错。这些都求完后,套用线段树模板即可,注意数据范围适当开long long 和随时对23333取模,因为这个点调试了一晚上加一上午才调好。AC代码:

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
#include <cmath>
using namespace std;
const int N = 1e5+50;
const int mod = 23333;
typedef long long ll;
ll lazy[N << 2] = {0}, sum1[N << 2] = {0}, sum2[N << 2] = {0};
int tot = 0, n, q;
int in[N], out[N], dfsn[N], a[N];
vector<int > v[N];
void dfsx(int x, int pre)
{
	in[x] = ++tot;
	dfsn[tot] = x;
	for (int i = 0; i < v[x].size(); i ++)
	{
		int j = v[x][i];
		if (j != pre)
		  dfsx(j,x);
	}
	out[x] = tot;
}
void pushup(int node)
{
	sum1[node] = (sum1[node << 1] + sum1[node << 1|1]) % mod;
	sum2[node] = (sum2[node << 1] + sum2[node << 1|1]) % mod;
}
void build(int node, int l, int r)
{
	if (l == r)
	{
		sum1[node] = a[dfsn[l]] % mod;
		sum2[node] = (sum1[node] * sum1[node]) % mod;//这里玄学,用a[dfsn[l]] % mod会WA,用sum1[node] 则不会..
		return;
 	}
	int mid = (l + r) >> 1;
	build(node << 1, l, mid);
	build(node << 1|1, mid+1, r);
	pushup(node);
}
void pushdown(int node, int l, int r)//利用推好的公式还原相应的值
{
	if (!lazy[node]) return;
	int mid = (l + r) >> 1;
	lazy[node << 1] = (lazy[node << 1] + lazy[node]) % mod;
	lazy[node << 1|1] = (lazy[node << 1|1] + lazy[node]) % mod;
	sum2[node << 1] = (sum2[node << 1] + (lazy[node] * lazy[node] * (mid - l + 1)) % mod + (2 * lazy[node] % mod * sum1[node << 1] % mod) % mod) % mod;
	sum2[node << 1|1] = (sum2[node << 1|1] + (lazy[node] * lazy[node] * (r - mid)) % mod + (2 * lazy[node] % mod * sum1[node << 1|1] % mod) % mod) % mod;
	sum1[node << 1] = (sum1[node << 1] + (mid - l + 1) * lazy[node]) % mod;
	sum1[node << 1|1] = (sum1[node << 1|1] + (r - mid) * lazy[node]) % mod;
	lazy[node] = 0;
}
void update(int node ,int l, int r, int L, int R,int val)
{
	if (l >= L && r <= R)
	{
		lazy[node] = (lazy[node] % mod + val) % mod;
		sum2[node] = (sum2[node] + val % mod * val % mod * (r - l + 1) % mod + 2 * val % mod * sum1[node] % mod) % mod;
		sum1[node] = (sum1[node] + (r - l + 1) * val % mod) % mod;
		return;
	}
	pushdown(node, l, r);
	int mid = (l + r) >> 1;
	if (L <= mid) 
	  update(node << 1, l, mid, L, R, val);
	if (R > mid)
	  update(node << 1|1, mid + 1, r, L, R, val);
	pushup(node);
}
ll query(int node, int l, int r, int L, int R)
{
	if (l >= L && r <= R)
	  return sum2[node] % mod;
	pushdown(node, l, r);
	int mid = (l + r) >> 1;
	ll s1 = 0, s2 = 0;
	if (L <= mid)
	  s1 += query(node << 1, l, mid, L, R) % mod;
	if (R > mid)
	  s2 += query(node << 1|1, mid + 1, r, L, R) % mod;
	return (s1 + s2) % mod;
}
int main()
{
	cin >> n >> q;
	for (int i = 1; i <= n; i ++)
		cin >> a[i];
	for(int i = 1; i < n; i ++)
	{
		int e1, e2;
		cin >> e1 >> e2;
		v[e1].push_back(e2);
		v[e2].push_back(e1);
	}
	dfsx(1, 0);
	build(1, 1, n);
 	while (q--)
	{
		int cmd, b, c;
		cin >> cmd;
		if (cmd == 1)
		{
			cin >> b >> c;
			c %= mod;
			update(1, 1, n, in[b], out[b], c);
		}
		else
		{
			cin >> b;
			ll ans = query(1, 1, n, in[b], out[b]);
			cout << ans << endl;
		}
	}
	return 0;
}
posted @ 2020-07-15 14:46  Hayasaka  阅读(61)  评论(0编辑  收藏  举报