线段树

重构一下线段树的博客,关于线段树的相关定义以及证明过段时间再补
首先是一个简单线段树,这里叫做伪线段树,其实本质就是一个二叉树,仅能支持单点操作:

单点修改 + 区间查询

// 单点修改查询
// http://ybt.ssoier.cn:8088/problem_show.php?pid=1549
// https://www.luogu.com.cn/problem/P1198
// 用一维数组来存,当作完全二叉树来存
#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 10;
long long int m, p, n, last, t;
char op;
struct node
{
    int l, r, v;
} tr[N * 4];
void pushup(int u) // 更新每个区间的最大值
{
    tr[u].v = max(tr[u * 2].v, tr[2 * u + 1].v);
}
void build(int u, int l, int r) // 建立线段树
{
    tr[u] = {l, r};
    if (l == r)
        return;
    int mid = l + r >> 1;
    build(2 * u, l, mid), build(2 * u + 1, mid + 1, r); // 向左建立,向右建立
}
int query(int u, int l, int r) // 查询
{
    if (tr[u].l >= l && tr[u].r <= r)
        return tr[u].v;
    int mid = tr[u].l + tr[u].r >> 1, v = 0;
    if (l <= mid)
        v = query(2 * u, l, r);
    if (r > mid)
        v = max(v, query(2 * u + 1, l, r));
    return v;
}
int modify(int u, int x, int v) // 修改
{
    if (tr[u].l == x && tr[u].r == x)
        tr[u].v = v;
    else
    {
        int mid = tr[u].l + tr[u].r >> 1;
        if (x <= mid)
            modify(2 * u, x, v);
        else
            modify(2 * u + 1, x, v);
        pushup(u);
    }
}
int main()
{
    cin >> m >> p;
    build(1, 1, m);
    while (m--)
    {
        cin >> op >> t;
        if (op == 'Q')
        {
            last = query(1, n - t + 1, n);
            cout << last << endl;
        }
        else
        {
            modify(1, n + 1, (last + t) % p);
            n++;
        }
    }
    return 0;
}

对于线段树的区间求和操作,我们只需要在线段树上对一个节点向上更新一下区间和就可以:

// //区间最大字段和
#include <bits/stdc++.h>
using namespace std;
const int N = 2e6 + 10;
int n, m, x, y, k, a[N];
struct node
{
    int l, r, all, lmax, rmax, sum;
} tr[N * 4];
void pushup(node &w, node &l, node &r)
{
    w.sum = l.sum + r.sum;
    w.lmax = max(l.lmax, l.sum + r.lmax), w.rmax = max(r.sum + l.rmax, r.rmax);
    w.all = max(max(l.all, r.all), l.rmax + r.lmax);
}
void pushup(int u)
{
    pushup(tr[u], tr[2 * u], tr[2 * u + 1]);
}
void build(int u, int l, int r)
{
    if (l == r)
        tr[u] = {l, r, a[l], a[l], a[l], a[l]};
    else
    {
        tr[u] = {l, r};
        int mid = l + r >> 1;
        build(2 * u, l, mid), build(2 * u + 1, mid + 1, r);
        pushup(u);
    }
}
void modify(int u, int x, int v)
{
    if (tr[u].l == x && tr[u].r == x)
        tr[u] = {x, x, v, v, v, v};
    else
    {
        int mid = tr[u].l + tr[u].r >> 1;
        if (x <= mid)
            modify(2 * u, x, v);
        else
            modify(2 * u + 1, x, v);
        pushup(u);
    }
}
node query(int u, int l, int r)
{
    if (tr[u].l >= l && tr[u].r <= r)
        return tr[u];
    int mid = tr[u].l + tr[u].r >> 1;
    if (r <= mid)
        return query(2 * u, l, r);
    else if (l > mid)
        return query(2 * u + 1, l, r);
    else
    {
        auto left = query(2 * u, l, r), right = query(2 * u + 1, l, r);
        node res;
        pushup(res, left, right);
        return res;
    }
}
int main()
{
    cin >> n >> m;
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    build(1, 1, n);
    while (m--)
    {
        cin >> k >> x >> y;
        if (k == 1)
        {
            if (x > y)
                swap(x, y);
            cout << query(1, x, y).all << endl;
        }
        else
            modify(1, x, y);
    }
    return 0;
}

此外,线段树还能够支持区间的最大公约数,也就向上更新一下就可以了:

// 区间最大公约数
#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 10;
char op;
int n, m, a[N];
struct node
{
    int l, r, sum, d;
} tr[4 * N];
void pushup(node &w, node &l, node &r)
{
    w.sum = l.sum + r.sum;
    w.d = __gcd(l.d, r.d);
}
void pushup(int u)
{
    pushup(tr[u], tr[2 * u], tr[2 * u + 1]);
}
void build(int u, int l, int r)
{
    if (l == r)
        tr[u] = {l, r, a[r] - a[r - 1], a[r] - a[r - 1]};
    else
    {
        tr[u] = {l, r};
        int mid = l + r >> 1;
        build(2 * u, l, mid), build(2 * u + 1, mid + 1, r);
        pushup(u);
    }
}
void modify(int u, int x, int v)
{
    if (tr[u].l == x && tr[u].r == x)
        tr[u] = {x, x, tr[u].sum + v, tr[u].sum + v};
    else
    {
        int mid = tr[u].l + tr[u].r >> 1;
        if (x <= mid)
            modify(2 * u, x, v);
        else
            modify(2 * u + 1, x, v);
        pushup(u);
    }
}
node query(int u, int l, int r)
{
    if (tr[u].l >= l && tr[u].r <= r)
        return tr[u];
    int mid = tr[u].l + tr[u].r >> 1;
    if (r <= mid)
        return query(2 * u, l, r);
    else if (l > mid)
        return query(2 * u + 1, l, r);
    else
    {
        auto left = query(2 * u, l, r), right = query(2 * u + 1, l, r);
        node res;
        pushup(res, left, right);
        return res;
    }
}
int main()
{
    int l, r, d;
    cin >> n >> m;
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    build(1, 1, n);
    while (m--)
    {
        cin >> op >> l >> r;
        if (op == 'Q')
        {
            auto left = query(1, 1, l), right = query(1, l + 1, r);
            cout << abs(__gcd(left.sum, right.d)) << endl;
        }
        else
        {
            cin >> d;
            modify(1, l, d);
            if (r + 1 <= n)
                modify(1, r + 1, -d);
        }
    }
    return 0;
}

区间修改 + 区间查询

接下来是支持区间修改的,对于区间操作,我们可以做个标记,然后只修改父节点,如果要查询子节点的时候我们直接向下更新标记,子区间再进行修改即可
区间加模板

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 1e6 + 10;
struct node
{
    int l, r, sum, lazy;
} tr[N * 4];
int a[N];
void pushdown(int u)
{
    auto &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
    left.sum += (left.r - left.l + 1) * root.lazy, left.lazy += root.lazy;
    right.sum += (right.r - right.l + 1) * root.lazy, right.lazy += root.lazy;
    root.lazy = 0;
}
void pushup(int u)
{
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void build(int u, int l, int r)
{
    if (l == r)
        tr[u] = {l, r, a[l], 0};
    else
    {
        tr[u] = {l, r};
        int mid = l + r >> 1;
        build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
        pushup(u);
    }
}
int query(int u, int l, int r)
{
    if (tr[u].l >= l && tr[u].r <= r)
        return tr[u].sum;
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1, sum = 0;
    if (l <= mid)
        sum = query(u << 1, l, r);
    if (r > mid)
        sum = sum + query(u << 1 | 1, l, r);
    return sum;
}
void modify(int u, int l, int r, int k)
{
    if (tr[u].l >= l && tr[u].r <= r)
    {
        tr[u].sum += (tr[u].r - tr[u].l + 1) * k;
        tr[u].lazy += k;
        return;
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    if (l <= mid)
        modify(u << 1, l, r, k);
    if (r > mid)
        modify(u << 1 | 1, l, r, k);
    pushup(u);
}
signed main()
{
    std::ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
    int n, m;
    cin >> n >> m;
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    build(1, 1, n);
    while (m--)
    {
        int op, l, r, k;
        cin >> op >> l >> r;
        if (op == 1)
            cin >> k, modify(1, l, r, k);
        else
            cout << query(1, l, r) << endl;
    }
    return 0;
}

区间乘

区间乘模板

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 2e5 + 10;
int n, p, m, a[N], mod, x, y, z, k;
struct node
{
    int l, r, sum, add, mul;
} tr[4 * N];
void pushup(int u)
{
    tr[u].sum = (tr[u << 1].sum + tr[u << 1 | 1].sum) % mod;
}
void val(node &t, int add, int mul)
{
    t.sum = (t.sum * mul + (t.r - t.l + 1) * add) % mod;
    t.mul = t.mul * mul % mod, t.add = (t.add * mul + add) % mod;
}
void pushdown(int u)
{
    val(tr[u << 1], tr[u].add, tr[u].mul), val(tr[u << 1 | 1], tr[u].add, tr[u].mul);
    tr[u].mul = 1, tr[u].add = 0;
}
void build(int u, int l, int r)
{
    if (l == r)
        tr[u] = {l, r, a[r], 0, 1};
    else
    {
        tr[u] = {l, r, 0, 0, 1};
        int mid = l + r >> 1;
        build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
        pushup(u);
    }
}
void modify(int u, int l, int r, int add, int mul)
{
    if (tr[u].l >= l && tr[u].r <= r)
        val(tr[u], add, mul);
    else
    {
        pushdown(u);
        int mid = tr[u].l + tr[u].r >> 1;
        if (l <= mid)
            modify(u << 1, l, r, add, mul);
        if (r > mid)
            modify(u << 1 | 1, l, r, add, mul);
        pushup(u);
    }
}
int query(int u, int l, int r)
{
    if (tr[u].l >= l && tr[u].r <= r)
        return tr[u].sum;
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1, res = 0;
    if (l <= mid)
        res = query(u << 1, l, r) % mod;
    if (r > mid)
        res = (res + query(u << 1 | 1, l, r)) % mod;
    return res;
}
signed main()
{
    cin >> n >> m >> p;
    mod = p;
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    build(1, 1, n);
    while (m--)
    {
        cin >> z >> x >> y;
        if (z == 1)
            cin >> k, modify(1, x, y, 0, k);
        else if (z == 2)
            cin >> k, modify(1, x, y, k, 1);
        else
            cout << query(1, x, y) % mod << endl;
    }
    return 0;
}

区间最大公约数

在前面, 我们已经介绍了一种不带区间修改的线段树方法, 现在我们来讨论一下如何对于一个带修的线段树进行区间查询, 首先对于最大公约数, 我们有如下公式: \(\text{gcd}(a, b, c) = \text{gcd}(a, b - a, c - b)\)
一个显然的方法就是我们用差分的形式在线段树上表示出来, 所以这样我们就将一个区间修改变成了一个单点修改, 那么对于这个修改, 很明显我们可以直接将变化后的值与原来的最大公约数再做一遍 \(gcd\) 即可, 这里给出一个证明: 假设 \(\text{gcd}(a, b, c) = x\) 此时我们将 \(b\) 变成 \(b + k\), 那么有: \(\text{gcd}(a, b, c) = \text{gcd}(x, b)\)
这里的 \(b\) 是已经修改过的. 时间复杂度 \(O(mlog^2n)\)
至此就完成了, 后面在区间最值还会用到, 不再给出证明

区间异或

此外,线段树还支持区间修改以及区间异或,这里与之前有所不同,如果按照之前的做法,我们进行区间修改的操作是要具体到某个叶节点的,因为这样我们才能进行区间的整体修改,但是这样的复杂度很明显,对于具体的区间修改,我们的复杂度是
\(O(n)\) 的,考虑如何进行优化:
首先对于一个 \(01\) 串,如果我们按位取反,很明显我们的0会变成1,1会变成0,设我们原来有 \(x\) 个1,那么取反之后我们就有 \(size - x\) 个1
那么由此联想到,我们可以把线段树的每个叶节点的值拆分成二进制模式,然后对其位数进行操作,对于异或的x,我们把他拆位,如果某个位是1,那么就打个标记代表要进行取反,那么这个区间中这个位的总数就变成了区间长度-区间内的数\(\oplus\) 这个位=1的个数,对于求和操作,把每个位的个数乘上 \(2^i\) 然后相加即可
区间异或

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 2e5 + 10, mod = 1e9 + 7;
struct node
{
    int l, r, should[25], tag[25];
    //should是每个位为1的个数
    //tag是每个位是否需要取反
} tr[N << 1];
int a[N];
void pushup(int u)
{
    //统计区间每个位的总数
    for (int i = 1; i <= 21; i++)
        tr[u].should[i] = tr[u << 1].should[i] + tr[u << 1 | 1].should[i];
}
void pushdown(int u)
{
    auto &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
    for (int i = 1; i <= 21; i++)
    {
        if (root.tag[i]) //若需要取反,则取反
        {
            left.should[i] = (left.r - left.l + 1) - left.should[i];
            right.should[i] = (right.r - right.l + 1) - right.should[i];
        }
        left.tag[i] ^= root.tag[i];
        right.tag[i] ^= root.tag[i];
        root.tag[i] = 0;
    }
}
void build(int u, int l, int r)
{
    tr[u] = {l, r};
    if (l == r)
    {
        for (int i = 1; i <= 21; i++)
            if (a[l] & (1 << (i - 1)))
                tr[u].should[i] = 1; //标记为1
    }
    else
    {
        int mid = l + r >> 1;
        build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
        pushup(u);
    }
}
void modify(int u, int l, int r, int x)
{
    if (tr[u].l >= l && tr[u].r <= r)
    {
        for (int i = 1; i <= 21; i++)
        {
            if (x & (1 << (i - 1)))
                //区间长度-目前的位=1的个数
                tr[u].should[i] = (tr[u].r - tr[u].l + 1) - tr[u].should[i];
            tr[u].tag[i] ^= (x & (1 << (i - 1))) ? 1 : 0;
        }
        return;
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    if (l <= mid)
        modify(u << 1, l, r, x);
    if (r > mid)
        modify(u << 1 | 1, l, r, x);
    pushup(u);
}
int query(int u, int l, int r)
{
    if (tr[u].l >= l && tr[u].r <= r)
    {
        int ans = 0;
        for (int i = 1; i <= 21; i++)
        {
            int now = (1 << (i - 1));
            ans += (now * tr[u].should[i]);
        }
        return ans;
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1, res = 0;
    if (l <= mid)
        res += query(u << 1, l, r);
    if (r > mid)
        res += query(u << 1 | 1, l, r);
    return res;
}
signed main()
{
    std::ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
    int n;
    cin >> n;
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    build(1, 1, n);
    int m;
    cin >> m;
    while (m--)
    {
        int op;
        cin >> op;
        if (op == 1)
        {
            int l, r;
            cin >> l >> r;
            cout << query(1, l, r) << '\n';
        }
        else
        {
            int x, l, r;
            cin >> l >> r >> x;
            modify(1, l, r, x);
        }
    }
    return 0;
}

同时进行操作

对于区间同时进行多种操作, 我们考虑添加多个懒标记进行操作, 同时要注意次序问题, 这里拿出 Hdu 4578 举例:
题目要求区间加, 乘, 修改成相同的数, 我们考虑用 \(3\)\(Lazy - tag\) 来进行标记, 这里主要难点在于求出区间的和, 平方和, 立方和
考虑推公式:

代码:

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 1e6 + 10, mod = 10007;
struct tree{
    int l, r, tag[3], sum[3];
}tr[N << 1];
int a[N];
void pushup(int u){
    tr[u].sum[0] = (tr[u << 1].sum[0] + tr[u << 1 | 1].sum[0]) % mod;
    tr[u].sum[1] = (tr[u << 1].sum[1] + tr[u << 1 | 1].sum[1]) % mod;
    tr[u].sum[2] = (tr[u << 1].sum[2] + tr[u << 1 | 1].sum[2]) % mod;
}
void make1(tree &x, int val){
    x.tag[0] = 0; 
    x.tag[1] = 1;
    x.tag[2] = val;
    x.sum[0] = (x.r - x.l + 1) * x.tag[2] % mod;
    x.sum[1] = (x.r - x.l + 1) * x.tag[2] % mod * x.tag[2] % mod;
    x.sum[2] = (x.r - x.l + 1) * x.tag[2] % mod * x.tag[2] % mod * x.tag[2] % mod;
}
// ok, let's fix everything now
void make2(tree &x, int add){ 
    // 这里注意, 要倒序更新, 类似与背包, 上一个更改的值会影响下一个的值
    x.sum[2] = ((x.sum[2] + (x.r - x.l + 1) * add % mod * add % mod * add % mod) % mod + 
                (3 * x.sum[0] % mod * add % mod * add % mod + 3 * x.sum[1] % mod * add % mod) % mod) % mod;
    x.sum[1] = ((x.sum[1] + (x.r - x.l + 1) * add % mod * add % mod) % mod + 2 * x.sum[0] % mod * add % mod) % mod;
    x.sum[0] = ((x.r - x.l + 1) * add % mod + x.sum[0] % mod) % mod;
    x.tag[0] = (x.tag[0] + add) % mod;
}   
void make3(tree &x, int mul){
    x.sum[0] = x.sum[0] * mul % mod;
    x.sum[1] = x.sum[1] * mul % mod * mul % mod;
    x.sum[2] = x.sum[2] * mul % mod * mul % mod * mul % mod;
    x.tag[0] = x.tag[0] * mul % mod;
    x.tag[1] = x.tag[1] * mul % mod;
}
void pushdown(int u){
    // 这里的顺序也不可变 , 由于我们后面的区间加操作与求和相关, 所以如果有相乘操作要先进行操作
    if(tr[u].tag[2] != 0){
        make1(tr[u << 1], tr[u].tag[2]);
        make1(tr[u << 1 | 1], tr[u].tag[2]);
        tr[u].tag[2] = 0;
    }
    if(tr[u].tag[1] != 1){
        make3(tr[u << 1], tr[u].tag[1]);
        make3(tr[u << 1 | 1], tr[u].tag[1]);
        tr[u].tag[1] = 1;
    }
    if(tr[u].tag[0] != 0){
        make2(tr[u << 1], tr[u].tag[0]);
        make2(tr[u << 1 | 1], tr[u].tag[0]);
        tr[u].tag[0] = 0;
    }
}
void build(int u, int l, int r ){
    if(l == r){
        tr[u] = {l, r, {0, 1, 0}, {0, 0, 0}};
        return;
    }
    tr[u] = {l, r, {0, 1, 0}, {0, 0, 0}};
    int mid = l + r >> 1;
    build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
    pushup(u);
}
void modify(int u, int l, int r, int add, int mul, int change){
    if(tr[u].l >= l && tr[u].r <= r){
        if(change != 0){
            make1(tr[u], change);
        }
        else if(add != 0){
            make2(tr[u], add);
        }
        else if(mul != 1){
            make3(tr[u], mul);
        }
        return;
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    if(l <= mid) modify(u << 1, l, r, add, mul, change);
    if(r > mid) modify(u << 1 | 1, l, r, add, mul, change);
    pushup(u);    
}
int query(int u, int l, int r, int op){
    if(tr[u].l >= l && tr[u].r <= r){
        return tr[u].sum[op];
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1, res = 0;
    if(l <= mid) res = query(u << 1, l, r, op);
    if(r > mid) res = (res + query(u << 1 | 1, l, r, op)) % mod;
    pushup(u);
    return res;
}
signed main()
{
    std::ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
    int n, m; 
    while(cin >> n >> m && n != 0 && m != 0){
        build(1, 1, n);
        while(m--){
            int op, l, r, p; cin >> op >> l >> r >> p;
            if(op == 1) modify(1, l, r, p, 1, 0);
            else if(op == 2) modify(1, l, r, 0, p, 0);
            else if(op == 3) modify(1, l, r, 0, 1, p);
            else cout << query(1, l, r, p - 1) << '\n';
        }
    }     
    return 0;
}

区间最值操作

上述的操作基本都是对区间的每一个数都做出改变, 通过如此操作我们的总和也较容易维护, 但是如果出现这一种操作: 对于一个区间 \([L,R]\) 中的所有数做: \(a_i = \text{min}(a_i, x)\), 这样的操作也是可以通过线段树来进行 \(log\) 级别的维护的, 具体方法如下:
线段树上额外维护这些值: 区间的最大值 \(mx\), 区间的次大值 \(se\), 区间最大值的个数 \(num\), 区间和 \(sum\)
那么当我们要进行一个操作的时候, 如果有:

  • \(x >= mx\) 那么很明显不需要更新
  • \(se <= x < mx\), 那么此时符合条件的只有最大值, 我们直接把最大值改成 \(x\), 然后更新一下总和即可, \(sum = sum - num * (mx - x)\)
  • \(x < se\) 那么这个区间是更新不了, 递归左右儿子更新即可
    此时就完成了, 这么做的复杂度是 \(log\), 先不做证明了, 后面补
    这里给出一个非常经典的例题: 最假女选手:
    注意, 在更新最大和次大以及最小和次小的时候注意两者是否可能相等, 若相等还需要进一步更改
#include <bits/stdc++.h>
#define ls u << 1
#define rs u << 1 | 1
#define LL long long
#define inf 1000000000
using namespace std;
const int N = 5e5 + 10, mod = 1e9 + 7;
int a[N];
struct node{
    int l, r, tag1[2], tag2[2], num[2], vis[2], add;
    LL sum;
}tr[N << 2];

void make(int u, int x){
    tr[u].sum = tr[u].sum + 1LL * (tr[u].r - tr[u].l + 1) * x;
    tr[u].tag1[0] += x, tr[u].tag1[1] += x;
    tr[u].tag2[0] += x, tr[u].tag2[1] += x;
    tr[u].add += x;
}

void make1(int u, int x){
    if(x >= tr[u].tag1[0]) return;
    tr[u].sum = tr[u].sum - 1LL * tr[u].num[0] * (tr[u].tag1[0] - x);
    if(tr[u].tag2[0] == tr[u].tag1[0]) tr[u].tag2[0] = x;
    if(tr[u].tag2[1] == tr[u].tag1[0]) tr[u].tag2[1] = x;
    tr[u].tag1[0] = tr[u].vis[0] = x;
}
void make2(int u, int x){
    if(x <= tr[u].tag2[0]) return;
    tr[u].sum = tr[u].sum + 1LL * tr[u].num[1] * (x - tr[u].tag2[0]);
    if(tr[u].tag1[0] == tr[u].tag2[0]) tr[u].tag1[0] = x;
    if(tr[u].tag1[1] == tr[u].tag2[0]) tr[u].tag1[1] = x;
    tr[u].tag2[0] = tr[u].vis[1] = x;   
}
void pushup(int u){
    tr[u].sum = tr[ls].sum + tr[rs].sum;
    tr[u].tag1[0] = max(tr[ls].tag1[0], tr[rs].tag1[0]); 
    if(tr[ls].tag1[0] == tr[rs].tag1[0]){
        tr[u].tag1[1] = max(tr[ls].tag1[1], tr[rs].tag1[1]);
        tr[u].num[0] = tr[ls].num[0] + tr[rs].num[0];
    } 
    else{
        tr[u].tag1[1] = min(tr[ls].tag1[0], tr[rs].tag1[0]);
        tr[u].tag1[1] = max({tr[u].tag1[1], tr[ls].tag1[1], tr[rs].tag1[1]}); 
        tr[u].num[0] = tr[ls].tag1[0] > tr[rs].tag1[0] ? tr[ls].num[0] : tr[rs].num[0];
    }
    tr[u].tag2[0] = min(tr[ls].tag2[0], tr[rs].tag2[0]); 
    if(tr[ls].tag2[0] == tr[rs].tag2[0]){
        tr[u].tag2[1] = min(tr[ls].tag2[1], tr[rs].tag2[1]);
        tr[u].num[1] = tr[ls].num[1] + tr[rs].num[1];
    } 
    else{
        tr[u].tag2[1] = max(tr[ls].tag2[0], tr[rs].tag2[0]);
        tr[u].tag2[1] = min({tr[u].tag2[1], tr[ls].tag2[1], tr[rs].tag2[1]}); 
        tr[u].num[1] = tr[ls].tag2[0] < tr[rs].tag2[0] ? tr[ls].num[1] : tr[rs].num[1];
    }
}

void pushdown(int u){
    if(tr[u].add != 0){
        make(ls, tr[u].add), make(rs, tr[u].add);
        tr[u].add = 0;
    }
    make1(ls, tr[u].tag1[0]), make1(rs, tr[u].tag1[0]);
    make2(ls, tr[u].tag2[0]), make2(rs, tr[u].tag2[0]);
}

void build(int u, int l, int r){
    if(l == r){
        tr[u] = {l, r, {a[l], -inf}, {a[l], inf}, {1, 1}, {-inf, inf}, 0, 1LL * a[l]};
        return;
    }
    tr[u] = {l, r, {-inf, -inf}, {inf, inf}, {0, 0}, {-inf, inf}, 0, 1LL * 0};
    int mid = l + r >> 1;
    build(ls, l, mid), build(rs, mid+1 ,r);
    pushup(u);
}

void modifyadd(int u, int l, int r, int x){
    if(tr[u].l >= l && tr[u].r <= r){
        make(u, x);
        return;
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    if(l <= mid) modifyadd(ls, l, r, x);
    if(r > mid) modifyadd(rs, l, r, x);
    pushup(u);
}

void modifymin(int u, int l, int r, int x){
    if(tr[u].tag1[0] <= x) return;
    if(tr[u].l >= l && tr[u].r <= r && x > tr[u].tag1[1]){
        make1(u, x);
        return;
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    if(l <= mid) modifymin(ls, l, r, x);
    if(r > mid) modifymin(rs, l, r, x);
    pushup(u);
}

void modifymax(int u, int l, int r, int x){
    if(tr[u].tag2[0] >= x) return;
    if(tr[u].l >= l && tr[u].r <= r && x < tr[u].tag2[1]){
        make2(u, x);
        return;
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    if(l <= mid) modifymax(ls, l, r, x);
    if(r > mid) modifymax(rs, l, r, x);
    pushup(u);
}

LL querysum(int u, int l, int r){
    if(tr[u].l >= l && tr[u].r <= r){
        return tr[u].sum;
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    LL res = 0;
    if(l <= mid) res += querysum(ls, l, r);
    if(r > mid) res += querysum(rs, l, r);
    return res;
}

int querymax(int u, int l, int r){
    if(tr[u].l >= l && tr[u].r <= r){
        return tr[u].tag1[0];
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1, res = -inf;
    if(l <= mid) res = max(res, querymax(ls, l, r));
    if(r > mid) res = max(res, querymax(rs, l, r));
    return res;
}

int querymin(int u, int l, int r){
    if(tr[u].l >= l && tr[u].r <= r){
        return tr[u].tag2[0];
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1, res = inf;
    if(l <= mid) res = min(res, querymin(ls, l, r));
    if(r > mid) res = min(res, querymin(rs, l, r));
    return res;
}

void debug(int i){
    cout << "----now----" << endl;
    cout << tr[i].l << " " << tr[i].r << endl;
    cout << tr[i].tag1[0] << " " << tr[i].tag1[1] << endl;
    cout << tr[i].tag2[0] << " " << tr[i].tag2[1] << endl;
    cout << tr[i].sum << endl;
    cout << "----end----" << endl;
}

signed main()
{
    // std::ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
    int n; cin >> n;
    for(int i = 1; i <= n; i++) cin >> a[i];
    build(1, 1, n);
    int m; cin >> m;
    while(m--){
        int op, l, r; cin >> op >> l >> r;
        if(op == 1){
            int x; cin >> x;
            modifyadd(1, l, r, x);
        }
        else if(op == 2){
            int x; cin >> x;
            modifymax(1, l, r, x);
        }
        else if(op == 3){
            int x; cin >> x;
            modifymin(1, l, r, x);
        }
        else if(op == 4) printf("%lld\n", querysum(1, l, r));
        else if(op == 5) printf("%d\n", querymax(1, l, r));
        else if(op == 6) printf("%d\n", querymin(1, l, r));
    }
    return 0;
}

主席树/可持久化线段树

旨在记录之前操作的状态, 其方法很好理解, 可以发现我们在进行一个单点更新的时候, 其实最多只会经过 \(log_n\) 的节点, 那么我们只需要向下图一样, 新建我们要用到的节点就可以了, 类似于动态开点一样, 然后我们每次记录不同的版本, 每个版本的入口就是根节点也就是 \(1'\), 其余的节点我们按照初始状态设置, 例如我们只需要经过左儿子, 那么我就不需要新建一个右儿子, 用原来的就可以, 形式化的就是 tr[new] = tr[old], 然后修改左儿子即可

【模板】可持久化线段树 1(可持久化数组)

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 1e6 + 10, mod = 1e9 + 7;

struct Seg{
    int l, r, val;
}tr[N * 10];
int cnt;
int root[N], last[N], a[N];
void build(int u, int l, int r){
    if(l == r){
        tr[u].val = a[l];
        return;
    }
    tr[u].l = ++cnt, tr[u].r = ++cnt;
    int mid = l + r >> 1;
    build(tr[u].l, l, mid), build(tr[u].r, mid + 1, r);
}
void modify(int &u, int old, int l, int r, int x, int v){
    u = ++cnt;
    tr[u] = tr[old], tr[u].val = v;
    if(l == r) return;
    int mid = l + r >> 1;
    if(x <= mid) modify(tr[u].l, tr[old].l, l, mid, x, v);
    else modify(tr[u].r, tr[old].r, mid + 1, r, x, v);
}
int query(int u, int pl, int pr, int l, int r){
    if(pl >= l && pr <= r){
        return tr[u].val;
    }
    int mid = pl + pr >> 1, res = 0;
    if(l <= mid) res = query(tr[u].l, pl, mid, l, r);
    if(r > mid) res = query(tr[u].r, mid + 1, pr, l, r);
    return res;
}
signed main()
{
    std::ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
    int n, q; cin >> n >> q;

    for(int i = 1; i <= n; i++) cin >> a[i];
    
    root[0] = ++cnt;
    build(root[0], 1, n);

    for(int i = 1; i <= q; i++){
        int v, op, loc; cin >> v >> op >> loc;
        root[i] = root[v];
        if(op == 1){
            int x; cin >> x;
            modify(root[i], root[v], 1, n, loc, x);
        } else{
            cout << query(root[v], 1, n, loc, loc) << '\n';
        }
    }

    return 0;
}

ABC253 - F
由于操作二可以把之前所有的操作一的影响消除掉, 所以我们使用主席树来维护操作时间状态, 每次查询时, 我们知道上一次进行操作二的状last[x]以及更改的值a[x], 那么结果就是a[x]加上从上次操作二之后的操作一的和, 用主席树相减可以抵消

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 1e6 + 10, mod = 1e9 + 7;

struct Seg{
    int l, r, val;
}tr[N << 2];
int cnt;
int root[N], last[N], a[N];
void build(int u, int l, int r){
    if(l == r) return;
    tr[u].l = ++cnt, tr[u].r = ++cnt;
    int mid = l + r >> 1;
    build(tr[u].l, l, mid), build(tr[u].r, mid + 1, r);
}
void modify(int &u, int old, int l, int r, int x, int v){
    u = ++cnt;
    tr[u] = tr[old], tr[u].val += v;
    if(l == r) return;
    int mid = l + r >> 1;
    if(x <= mid) modify(tr[u].l, tr[old].l, l, mid, x, v);
    else modify(tr[u].r, tr[old].r, mid + 1, r, x, v);
}
int query(int u, int old, int pl, int pr, int l, int r){
    if(pl >= l && pr <= r){
        return tr[u].val - tr[old].val;
    }
    int mid = pl + pr >> 1, res = 0;
    if(l <= mid) res += query(tr[u].l, tr[old].l, pl, mid, l, r);
    if(r > mid) res += query(tr[u].r, tr[old].r, mid + 1, pr, l, r);
    return res;
}
signed main()
{
    std::ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
    int n, m, q; cin >> n >> m >> q;

    root[0] = ++cnt;
    build(root[0], 1, m);

    for(int i = 1; i <= q; i++){
        int op; cin >> op;
        root[i] = root[i - 1];
        if(op == 1){
            int l, r, x; cin >> l >> r >> x;
            modify(root[i], root[i], 1, m, l, x);
            if(r + 1 <= m){
                modify(root[i], root[i], 1, m, r + 1, -x);
            }
        } else if(op == 2){
            int row, x; cin >> row >> x;
            last[row] = i, a[row] = x;            
        } else{
            int x, y; cin >> x >> y;
            cout << a[x] + query(root[i], root[last[x]], 1, m, 1, y) << '\n';
        }
    }

    return 0;
}
posted @ 2024-05-27 23:58  o-Sakurajimamai-o  阅读(22)  评论(0编辑  收藏  举报
-- --