线段树

线段树

区间加,区间和

#include <iostream>
#include <cstring>
using namespace std;
using ll = long long;
constexpr int N = 100010;   // capacity for the element array

// Initial values; w[i] becomes the leaf value of position i in build().
ll w[N];

// One segment-tree node covering positions [l, r].
//   add : pending "add this to every element" tag not yet pushed to the children
//   sum : sum of the segment, already including this node's own pending tag
struct Node
{
    int l, r;
    ll add, sum;
};
// A segment tree over up to N leaves needs at most 4 * N nodes.
Node tr[N * 4];

int n, m;


void pushup(int u)
{
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}

void pushdown(int u)
{
    Node &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
    if (root.add)
    {
        left.sum += root.add * (left.r - left.l + 1), right.sum += root.add * (right.r - right.l + 1);
        left.add += root.add, right.add += root.add;
        root.add = 0;
    }
    
}
void build(int u, int l, int r)
{
    if (l == r)
    {
        tr[u] = {l, r, 0, w[l]};
        return;
    }
    tr[u] = {l, r};
    int mid = l + r >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
    pushup(u);
}

void modify(int u, int l, int r, ll k)
{
    if (tr[u].l >= l && tr[u].r <= r)
    {
        tr[u].sum += k * (tr[u].r - tr[u].l + 1);
        tr[u].add += k;
        return;
    }
    int mid = tr[u].l + tr[u].r >> 1;
    pushdown(u);
    if (l <= mid)
        modify(u << 1, l, r, k);
    if (r > mid)
        modify(u << 1 | 1, l, r, k);
    pushup(u);
}

// Return the sum of positions [l, r] that lie inside node u's segment.
ll query(int u, int l, int r)
{
    const Node &cur = tr[u];
    if (l <= cur.l && cur.r <= r)
        return cur.sum;
    pushdown(u);   // children must be up to date before we read them
    const int mid = (cur.l + cur.r) >> 1;
    ll total = 0;
    if (l <= mid)
        total += query(2 * u, l, r);
    if (mid < r)
        total += query(2 * u + 1, l, r);
    return total;
}

区间替换,区间和

#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
typedef long long ll;
const int N = 1e5 + 10;
int n;

// Segment-tree node for "assign a value to a range" + "sum over a range".
//   c      : sum of all elements in [l, r] (for a leaf, simply its value)
//   lazy   : value every element of both children must be set to (pending)
//   tagged : whether `lazy` is pending; kept separately because 0 is a valid
//            value to assign, so lazy == 0 cannot double as "no pending tag"
struct Node
{
	int l, r;
	ll c, lazy;
	bool tagged;
}tr[N<<2];

// Set every element of node u's segment to k and record k as u's pending tag.
void apply_assign(int u, ll k)
{
	tr[u].c = k * (tr[u].r - tr[u].l + 1);
	tr[u].lazy = k;
	tr[u].tagged = true;
}

// Recompute node u's sum from its two children.
void pushup(int u)
{
	tr[u].c = tr[u<<1].c + tr[u<<1|1].c;
}

// Hand node u's pending assignment (if any) down to both children.
void pushdown(int u)
{
	if (!tr[u].tagged)
		return;
	apply_assign(u<<1, tr[u].lazy);
	apply_assign(u<<1|1, tr[u].lazy);
	tr[u].lazy = 0;
	tr[u].tagged = false;
}

// Build the subtree rooted at u over [l, r]; every element starts with value 1,
// so each node's sum is simply its length.
void build(int u, int l, int r)
{
	tr[u] = {l, r, r - l + 1, 0, false};
	if (l == r)
		return;
	int mid = l + r >> 1;
	build(u << 1, l, mid);
	build(u << 1 | 1, mid + 1, r);
}

// Set every position in [l, r] (within u's segment) to k.
void modify(int u, int l, int r, ll k)
{
	if (tr[u].l >= l && tr[u].r <= r)
	{
		apply_assign(u, k);
		return;
	}
	pushdown(u);
	int mid = tr[u].l + tr[u].r >> 1;
	if (l <= mid)
		modify(u << 1, l, r, k);
	if (r > mid)
		modify(u << 1 | 1, l, r, k);
	pushup(u);   // keep this node's sum consistent with its updated children
}

// Return the sum of positions [l, r] (within u's segment).
// For a single point (l == r) this is that element's current value.
ll query(int u, int l, int r)
{
	if (tr[u].l >= l && tr[u].r <= r)
		return tr[u].c;
	pushdown(u);
	int mid = tr[u].l + tr[u].r >> 1;
	ll s = 0;
	// A range may straddle mid, so both halves must be visited when they overlap.
	if (l <= mid)
		s += query(u << 1, l, r);
	if (r > mid)
		s += query(u << 1 | 1, l, r);
	return s;
}

区间加,区间乘,区间和

#include <iostream>
#include <cstring>
typedef long long ll;   // a real type alias rather than a text-substitution macro
using namespace std;
const int N = 100010;
int n, m, p;   // n: array length, m: number of operations, p: modulus for all arithmetic
int w[N];      // initial values, read 1-indexed in main

// One segment-tree node covering positions [l, r].
// Pending tag (add, mult) means: every element x of both children still has to
// become x * mult + add (mod p). sum already reflects this node's own tag.
struct Node
{
	int l, r;
	ll add, mult, sum;
};
// At most 4 * N nodes are needed for N leaves.
Node tr[N * 4];
// Apply the affine update x -> x * mult + add (mod p) to every element of t's segment:
// rescale and shift t's sum, and compose the update after t's existing pending tag
// ((x * m + a) * mult + add  ==  x * (m * mult) + (a * mult + add)).
void eqal(Node &t, ll add, ll mult)
{
	const ll len = t.r - t.l + 1;
	const ll scaled = t.sum * mult % p;
	const ll shifted = add * len % p;
	t.sum = (scaled + shifted) % p;
	t.add = (t.add * mult + add) % p;
	t.mult = t.mult * mult % p;
}
void pushup(int u)
{
	tr[u].sum = (tr[u << 1].sum + tr[u << 1 | 1].sum) % p;
}
void pushdown(int u)
{
	Node &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
	eqal(left, root.add, root.mult);
	eqal(right, root.add, root.mult);
	root.add = 0, root.mult = 1;
}
// Build the subtree rooted at u over [l, r] with identity tags (add 0, mult 1).
// Leaf sums are stored exactly as read from w[]; internal sums are reduced mod p.
void build(int u, int l, int r)
{
	tr[u] = {l, r, 0, 1, l == r ? w[l] : 0};
	if (l != r)
	{
		const int mid = (l + r) >> 1;
		build(2 * u, l, mid);
		build(2 * u + 1, mid + 1, r);
		pushup(u);
	}
}
void modify(int u, int l, int r, ll add, ll mult)
{
	if (tr[u].l >= l && tr[u].r <= r)
	{
		eqal(tr[u], add, mult);
		return;
	}
	pushdown(u);
	int mid = tr[u].l + tr[u].r >> 1;
	if (l <= mid)
		modify(u << 1, l, r, add, mult);
	if (r > mid)
		modify(u << 1 | 1, l, r, add, mult);
	pushup(u);
}
// Return the sum of positions [l, r] within u's segment, reduced mod p.
// pushdown leaves tr[u].sum unchanged and consistent with its children, so no
// pushup is needed on this read-only path.
ll query(int u, int l, int r)
{
	if (tr[u].l >= l && tr[u].r <= r)
		return tr[u].sum % p;   // leaves built from w[] may still hold unreduced values
	pushdown(u);
	int mid = tr[u].l + tr[u].r >> 1;
	ll sum = 0;
	if (l <= mid)
		sum += query(u << 1, l, r);
	if (r > mid)
		sum += query(u << 1 | 1, l, r);
	return sum % p;   // each part is already < p, so one final reduction suffices
}
int main()
{
	ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
	cin >> n >> m >> p;
	for (int i = 1; i <= n; i++)
		cin >> w[i];
	build(1, 1, n);
	while(m--)
	{
		int op, l, r, k;
		cin >> op;
		if (op == 1)
		{
			cin >> l >> r >> k;
			modify(1, l, r, 0, k);
		}
		else if (op == 2)
		{
			cin >> l >> r >> k;
			modify(1, l, r, k, 1);
		}
		else if (op == 3)
		{
			cin >> l >> r;
			cout << query(1, l, r) << endl;
		}
	}
	return 0;
}
posted @ 2022-05-16 23:43  hzy0227  阅读(18)  评论(0)  编辑  收藏  举报