线段树

线段树

区间加、区间和、区间最值、区间推平（推平的写法以注释形式给出）

#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
#include <cmath>
#include <cstdio>
using namespace std;

// 64-bit integer type used for all stored values and aggregates.
using ll = long long;
// (max, min) pair returned by query().
using PLL = pair<ll, ll>;
// Capacity of the input array: up to 1e5 elements plus slack.
constexpr int N = 100000 + 10;
// Sentinel beyond any element value; identity for the max/min folds in query().
constexpr ll INF = 1e18;

// Input array; build() copies h[l..r] into the leaves.
ll h[N];
// One segment-tree node: the index interval it covers plus its aggregates.
struct Node
{
	int l;  // left end of the covered interval [l, r]
	int r;  // right end of the covered interval
	ll add; // pending lazy tag not yet pushed to the children
	ll sum; // sum over [l, r]
	ll mx;  // maximum over [l, r]
	ll mn;  // minimum over [l, r]
};
// The node array must have room for 4x the element count.
Node tr[N << 2];

void pushup(int u)
{
	tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
	tr[u].mx = max(tr[u << 1].mx, tr[u << 1 | 1].mx);
	tr[u].mn = min(tr[u << 1].mn, tr[u << 1 | 1].mn);
}
void pushdown(int u)
{
	//说明该区间之前更新过
	//要想更新该区间下面的子区间,就要把上次更新该区间的值向下更新
	if (tr[u].add)
	{
		//替换原来的值
		/*
		tr[u<<1].sum = (tr[u<<1].r-tr[u<<1].l+1)*tr[u].add;
		tr[u<<1|1].sum = (tr[u<<1|1].r-tr[u<<1|1].l+1)*tr[u].add;
		tr[u<<1].mx = tr[u].add;
		tr[u<<1|1].mx = tr[u].add;
		tr[u<<1].mn = tr[u].add;
		tr[u<<1|1].mn = tr[u].add;
		tr[u<<1].add = tr[u].add;
		tr[u<<1|1].add = tr[u].add;
		tr[u].add = 0;*/
		//在原来的值的基础上加上val

		tr[u << 1].sum += (tr[u << 1].r - tr[u << 1].l + 1) * tr[u].add;
		tr[u << 1 | 1].sum += (tr[u << 1 | 1].r - tr[u << 1 | 1].l + 1) * tr[u].add;
		tr[u << 1].mx += tr[u].add;
		tr[u << 1 | 1].mx += tr[u].add;
		tr[u << 1].mn += tr[u].add;
		tr[u << 1 | 1].mn += tr[u].add;
		tr[u << 1].add += tr[u].add;
		tr[u << 1 | 1].add += tr[u].add;
		tr[u].add = 0;
	}
}
void build(int u, int l, int r)
{
	tr[u].l = l;
	tr[u].r = r;
	tr[u].add = 0; //刚开始一定要清0
	if (l == r)
	{
		tr[u].mn = tr[u].mx = tr[u].sum = h[l];
		return;
	}
	int mid = (l + r) >> 1;
	build(u << 1, l, mid);
	build(u << 1 | 1, mid + 1, r);
	pushup(u);
}
void modify(int u, int l, int r, ll val)
{
	if (l <= tr[u].l && r >= tr[u].r)
	{
		/*把原来的值替换成val,因为该区间有tr[u].r-tr[u].l+1
		个数,所以区间和 以及 最值为:
		*/
		/*tr[u].sum = (tr[u].r-tr[u].l+1)*val;
		tr[u].mn = val;
		tr[u].mx = val;
		tr[u].add = val;//延时标记*/
		//在原来的值的基础上加上val,因为该区间有tr[u].r-tr[u].l+1
		//个数,所以区间和 以及 最值为:
		tr[u].sum += (tr[u].r - tr[u].l + 1) * val;
		tr[u].mn += val;
		tr[u].mx += val;
		tr[u].add += val; //延时标记

		return;
	}
	pushdown(u);
	int mid = (tr[u].l + tr[u].r) >> 1;
	if (l <= mid)
	{
		modify(u << 1, l, r, val);
	}
	if (r > mid)
	{
		modify(u << 1 | 1, l, r, val);
	}
	pushup(u);
}
// Returns {max, min} of the elements in [l, r], starting from node u.
// Partially covered nodes push their pending add tag down before recursing.
// Range-sum variant: return ll instead. In the covered case return
// tr[u].sum; otherwise start from ans = 0 and accumulate
//     ans += query(u << 1, l, r);      // if l <= mid
//     ans += query(u << 1 | 1, l, r);  // if r > mid
PLL query(int u, int l, int r)
{
	if (l <= tr[u].l && r >= tr[u].r)
	{
		return {tr[u].mx, tr[u].mn};
	}
	pushdown(u);
	int mid = (tr[u].l + tr[u].r) >> 1;
	ll maxn = -INF;
	ll minn = INF;
	if (l <= mid)
	{
		// Recurse once per child and reuse both fields. Calling query() twice
		// (once for .first, once for .second) doubled the work at every level,
		// turning an O(log n) query into roughly O(n).
		const PLL lres = query(u << 1, l, r);
		maxn = max(maxn, lres.first);
		minn = min(minn, lres.second);
	}
	if (r > mid)
	{
		const PLL rres = query(u << 1 | 1, l, r);
		maxn = max(maxn, rres.first);
		minn = min(minn, rres.second);
	}
	return {maxn, minn};
}

区间加、区间乘、区间和（结果对 p 取模）

#include <iostream>
#include <cstring>
#define ll long long
using namespace std;
// Capacity of the input array (main() fills indices 1..n).
constexpr int N = 100010;
int n;    // number of elements
int m;    // number of operations
int p;    // modulus for all arithmetic
int w[N]; // initial values, w[1..n]
// Segment-tree node for range updates of the form x -> x * mult + add (mod p).
struct Node
{
	int l;   // left end of the covered interval
	int r;   // right end of the covered interval
	ll add;  // pending additive tag
	ll mult; // pending multiplicative tag (1 means none)
	ll sum;  // sum over the covered interval, reduced mod p by pushup()/eqal()
};
Node tr[N * 4];
// Applies the tag (x -> x * mult + add) to node t, all modulo p.
// The node's sum is updated directly. The new tag is composed after t's
// existing one: (x * t.mult + t.add) * mult + add.
void eqal(Node &t, ll add, ll mult)
{
	const ll len = t.r - t.l + 1;
	t.add = (t.add * mult + add) % p;
	t.mult = t.mult * mult % p;
	t.sum = (t.sum * mult % p + add * len % p) % p;
}
void pushup(int u)
{
	tr[u].sum = (tr[u << 1].sum + tr[u << 1 | 1].sum) % p;
}
void pushdown(int u)
{
	Node &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
	eqal(left, root.add, root.mult);
	eqal(right, root.add, root.mult);
	root.add = 0, root.mult = 1;
}
void build(int u, int l, int r)
{
	if (l == r)
	{
		tr[u] = {l, r, 0, 1, w[l]};
		return;
	}
	tr[u] = {l, r, 0, 1, 0};
	int mid = l + r >> 1;
	build(u << 1, l, mid);
	build(u << 1 | 1, mid + 1, r);
	pushup(u);
}
// Applies x -> x * mult + add to every element in [l, r], starting from node u.
// Fully covered nodes absorb the tag via eqal(). Partially covered nodes push
// their own tag down first, recurse, and then re-aggregate.
void modify(int u, int l, int r, ll add, ll mult)
{
	const bool inside = l <= tr[u].l && tr[u].r <= r;
	if (inside)
	{
		eqal(tr[u], add, mult);
		return;
	}

	pushdown(u);
	const int mid = (tr[u].l + tr[u].r) >> 1;
	if (l <= mid)
		modify(u << 1, l, r, add, mult);
	if (mid < r)
		modify(u << 1 | 1, l, r, add, mult);
	pushup(u);
}
// Returns the sum of the elements in [l, r] under node u, modulo p.
ll query(int u, int l, int r)
{
	if (l <= tr[u].l && tr[u].r <= r)
		return tr[u].sum % p;

	pushdown(u);
	const int mid = (tr[u].l + tr[u].r) >> 1;

	ll leftPart = 0;
	if (l <= mid)
		leftPart = query(u << 1, l, r) % p;

	ll rightPart = 0;
	if (mid < r)
		rightPart = query(u << 1 | 1, l, r);

	pushup(u);
	return (leftPart + rightPart) % p;
}
// Reads n, m, p and the initial array, then processes m operations:
//   1 l r k : multiply every element in [l, r] by k
//   2 l r k : add k to every element in [l, r]
//   3 l r   : print the sum of [l, r] modulo p
int main()
{
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	cin >> n >> m >> p;
	for (int i = 1; i <= n; i++)
		cin >> w[i];
	build(1, 1, n);
	while (m--)
	{
		int op, l, r, k;
		cin >> op;
		if (op == 1)
		{
			cin >> l >> r >> k;
			modify(1, l, r, 0, k);
		}
		else if (op == 2)
		{
			cin >> l >> r >> k;
			modify(1, l, r, k, 1);
		}
		else if (op == 3)
		{
			cin >> l >> r;
			// '\n' instead of endl: endl flushes the stream on every answer,
			// which is needlessly slow for up to m output lines.
			cout << query(1, l, r) << '\n';
		}
	}
	return 0;
}
posted @ 2022-05-24 20:02  hzy0227  阅读(23)  评论(0编辑  收藏  举报