如何用线段树维护一些数学公式
1. 维护等差数列
例1:洛谷 P1438 无聊的数列(插入等差数列,单点查询)
这题有两个做法,第一个做法是用线段树维护等差数列,不过这里不多赘述,在下一个例子再详细介绍;第二个做法是用线段树维护差分数组,把单点查询转化为查询前缀和。
#include <bits/stdc++.h>
using namespace std;
const int N = 100005;
typedef long long ll;
int n, m;
int a[N];
struct Seg_Tree {
struct node {
int l, r;
ll s, tag;
} seg[N * 4];
void pushup(int k)
{
seg[k].s = seg[k << 1].s + seg[k << 1 | 1].s;
}
void set_tag(int k, ll add)
{
int len = seg[k].r - seg[k].l + 1;
seg[k].s += 1ll * add * len;
seg[k].tag += add;
}
void pushdown(int k)
{
if(seg[k].tag != 0)
{
set_tag(k << 1, seg[k].tag);
set_tag(k << 1 | 1, seg[k].tag);
seg[k].tag = 0;
}
}
void build(int k, int l, int r)
{
seg[k] = {l, r, 0, 0};
if(l == r) {
seg[k].s = a[l];
return;
}
int mid = (l + r) >> 1;
build(k << 1, l, mid);
build(k << 1 | 1, mid + 1, r);
pushup(k);
}
void modify(int k, int ql, int qr, ll v)
{
int l = seg[k].l, r = seg[k].r;
if(ql <= l && r <= qr)
{
set_tag(k, v);
return;
}
pushdown(k);
int mid = (l + r) >> 1;
if(ql <= mid)
modify(k << 1, l, mid, ql, qr, v);
if(mid < qr)
modify(k << 1 | 1, mid + 1, r, ql, qr, v);
pushup(k);
}
ll query(int k, int ql, int qr)
{
int l = seg[k].l, r = seg[k].r;
if(ql <= l && r <= qr)
return seg[k].s;
pushdown(k);
int mid = (l + r) >> 1;
ll res = 0;
if(ql <= mid)
res += query(k << 1, l, mid, ql, qr);
if(mid < qr)
res += query(k << 1 | 1, mid + 1, r, ql, qr);
return res;
}
};
Seg_Tree T1;
int main()
{
cin >> n >> m;
for(int i = 1; i <= n; i ++)
cin >> a[i];
for(int i = n; i >= 1; i --)
a[i] = a[i] - a[i - 1];
T1.build(1, 1, n);
while(m --)
{
int opt;
cin >> opt;
if(opt == 1)
{
int l, r, k, d;
cin >> l >> r >> k >> d;
T1.modify(1, l, l, k);
if(l != r)
T1.modify(1, l + 1, r, d);
if(r + 1 <= n)
T1.modify(1, r + 1, r + 1, (l - r) * d - k);
}
else
{
int p;
cin >> p;
cout << T1.query(1, 1, p) << '\n';
}
}
}
例2:牛牛的等差数列(插入等差数列,区间查询)
在这道题就详细地谈一下是怎么用线段树来维护等差数列的。
维护线段的区间和,懒标记分别为首项 \(x\) 和公差 \(d\) ,然后根据传到该区间的的首项和公差,再结合区间长度,通过等差数列求和公式计算出等差数列的和
又因为是区间修改,所以我们是需要懒标记的。懒标记中需要含有我们计算和的关键:首项和公差。我们可以思考如何对一段区间进行修改:若整个区间都在 \(mid\) 的左边或者整个区间都在 \(mid\) 的右边,那么首项和公差直接累加上去就行了;若区间被 \(mid\) 分割为两段,把等差数列分成两段,当然右边的首项跟左边的首项是不一样的。
那我们下传懒标记的时候,把懒标记的贡献加到区间和里,然后再累加懒标记。
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int N = 200005;
const ll mod = 111546435, inv2 = 55773218; //mod是lcm(3,25),inv2是2在mod下逆元
int n, q;
ll a[N];
ll calc(ll k, ll d, ll len)
{
return (1ll * len * k % mod + len * (len - 1) % mod * d % mod * inv2 % mod) % mod;
}
struct Seg_tree {
struct node {
int l, r;
ll v;
ll kk, d; //首项,公差
} seg[N << 2];
void pushup(int k)
{
seg[k].v = (seg[k << 1].v + seg[k << 1 | 1].v) % mod;
}
void build(int k, int l, int r)
{
seg[k] = {l, r, 0, 0, 0};
if(l == r) {
seg[k].v = a[l] % mod;
return;
}
int mid = (l + r) >> 1;
build(k << 1, l, mid);
build(k << 1 | 1, mid + 1, r);
pushup(k);
}
void pushdown(int k)
{
auto &l = seg[k << 1], &r = seg[k << 1 | 1];
auto &kk = seg[k].kk, &d = seg[k].d;
if(kk && d)
{
ll lenl = (l.r - l.l + 1), lenr = (r.r - r.l + 1);
(l.kk += kk) %= mod;
(l.d += d) %= mod;
(l.v += calc(kk, d, lenl)) %= mod;
ll kk2 = (kk + d * lenl % mod) % mod;
(r.kk += kk2) %= mod;
(r.d += d) %= mod;
(r.v += calc(kk2, d, lenr)) %= mod;
kk = d = 0;
}
}
void modify(int k, int ql, int qr, ll kk, ll d)
{
int l = seg[k].l, r = seg[k].r;
if(ql <= l && r <= qr) {
(seg[k].v += calc(kk, d, 1ll * (r - l + 1))) %= mod;
(seg[k].kk += kk) %= mod;
(seg[k].d += d) %= mod;
return;
}
pushdown(k);
int mid = (l + r) >> 1;
if(qr <= mid)
modify(k << 1, ql, qr, kk, d);
else if(ql > mid)
modify(k << 1 | 1, ql, qr, kk, d);
else
{
modify(k << 1, ql, mid, kk, d);
(kk += (d * 1ll * (mid - ql + 1) % mod)) %= mod;
modify(k << 1 | 1, mid + 1, qr, kk, d);
}
pushup(k);
}
ll query(int k, int ql, int qr)
{
int l = seg[k].l, r = seg[k].r;
if(ql <= l && r <= qr)
return seg[k].v;
pushdown(k);
int mid = (l + r) >> 1;
ll res = 0;
if(ql <= mid)
res += query(k << 1, ql, qr), res %= mod;
if(mid < qr)
res += query(k << 1 | 1, ql, qr), res %= mod;
return res;
}
};
Seg_tree T1;
signed main()
{
ios::sync_with_stdio(false);cin.tie(0);
cin >> n;
for(int i = 1; i <= n; i ++)
cin >> a[i];
T1.build(1, 1, n);
cin >> q;
while(q --)
{
int opt;
cin >> opt;
int l, r;
cin >> l >> r;
if(opt == 1)
{
ll kk, d;
cin >> kk >> d;
T1.modify(1, l, r, kk % mod, d % mod);
}
else
{
int m;
cin >> m;
cout << T1.query(1, l, r) % m << '\n';
}
}
}
2. 维护二次函数
例1. 智乃酱的平方数列
我们也可以用线段树来维护二次函数。
我们考虑对于 \([l,r]\) 添加平方数列,对于位置 \(\mathrm{x} \in [l, \mathrm{r}]\),增加的值应当是 \((x - (l - 1))^2\)。
展开后为: \(x^2+2(l-1)x+(l-1)^2\) 。那么这就是一个二次函数了,我们需要维护其系数。
我们需要三个懒标记,分别维护的是二次项 \(x^2\)的系数和,一次项 \(x\) 的系数和以及常数项的系数和。
那么若要维护一次函数也同理。
#include <bits/stdc++.h>
using namespace std;
const int N = 500005;
typedef long long ll;
const ll mod = 1000000007;
int n, m;
int a[N];
struct Seg_Tree {
struct node {
int l, r;
ll s;
ll base1, base2;
ll lazy1, lazy2, lazy3;
} seg[N * 4];
void pushup(int k)
{
seg[k].s = (seg[k << 1].s + seg[k << 1 | 1].s) % mod;
}
void set_tag(int k, ll lazy1, ll lazy2, ll lazy3)
{
int len = seg[k].r - seg[k].l + 1;
(seg[k].s += (seg[k].base1 * lazy1) % mod) %= mod;
(seg[k].s += ((-2ll * seg[k].base2 * lazy2 % mod) + mod) % mod) %= mod;
(seg[k].s += (1ll * len * lazy3 % mod)) %= mod;
(seg[k].lazy1 += lazy1) %= mod;
(seg[k].lazy2 += lazy2) %= mod;
(seg[k].lazy3 += lazy3) %= mod;
}
void pushdown(int k)
{
if(seg[k].lazy1 || seg[k].lazy2 || seg[k].lazy3)
{
set_tag(k << 1, seg[k].lazy1, seg[k].lazy2, seg[k].lazy3);
set_tag(k << 1 | 1, seg[k].lazy1, seg[k].lazy2, seg[k].lazy3);
seg[k].lazy1 = seg[k].lazy2 = seg[k].lazy3 = 0;
}
}
void build(int k, int l, int r)
{
seg[k].l = l, seg[k].r = r;
if(l == r) {
seg[k].s = a[l];
seg[k].base1 = 1ll * l * l % mod;
seg[k].base2 = l % mod;
return;
}
int mid = (l + r) >> 1;
build(k << 1, l, mid);
build(k << 1 | 1, mid + 1, r);
seg[k].base1 = (seg[k << 1].base1 + seg[k << 1 | 1].base1) % mod;
seg[k].base2 = (seg[k << 1].base2 + seg[k << 1 | 1].base2) % mod;
pushup(k);
}
void modify(int k,int ql, int qr, ll v)
{
int l = seg[k].l, r = seg[k].r;
if(ql <= l && r <= qr)
{
set_tag(k, 1, v, v * v % mod);
return;
}
pushdown(k);
int mid = (l + r) >> 1;
if(ql <= mid)
modify(k << 1, ql, qr, v);
if(mid < qr)
modify(k << 1 | 1, ql, qr, v);
pushup(k);
}
ll query(int k, int ql, int qr)
{
int l = seg[k].l, r = seg[k].r;
if(ql <= l && r <= qr)
return seg[k].s;
pushdown(k);
int mid = (l + r) >> 1;
ll res = 0;
if(ql <= mid)
res += query(k << 1, ql, qr), res %= mod;
if(mid < qr)
res += query(k << 1 | 1, ql, qr), res %= mod;
return res;
}
};
Seg_Tree T1;
int main()
{
ios::sync_with_stdio(false);cin.tie(0);
cin >> n >> m;
T1.build(1, 1, n);
while(m --)
{
int opt, l, r;
cin >> opt;
cin >> l >> r;
if(opt == 1)
T1.modify(1, l, r, l - 1);
else
cout << T1.query(1, l, r) << '\n';
}
}