线段树区间取膜438D - The Child and Sequence
438D - The Child and Sequence
题面
长度为n的非负整数数列,3种操作
- 求\([L,R]\)所有数的和。
- 将\([L,R]\)中所有数都\(mod \ x\)。
- 将\(a_i\)修改为\(v\)。
\(n,m≤100000\)
题解
对于一段区间,如果取模的数比这段区间所有的数都大,那取模就是没有意义的,就是说,如果取模的数比区间最大的数还大,那么就不用取模了,所以我们在线段树里再记录一个区间最大值
考虑每次取模,对于每一个数\(x\),取模\(y\),\(x \ mod \ y\)的值必然比\(y\)小,如果\(y\)小于\(\frac{x}{2}\),那x就变得小于\(\frac{x}{2}\),如果\(y\)大于\(\frac{x}{2}\),\(x\)剩下的部分也比\(\frac{x}{2}\)少,\(x\)也会变得比\(\frac{x}{2}\)小
那么就是说\(x\)每次取模都会变得比\(\frac{x}{2}\)小,就是说,对于一个数\(x\),有效的取\(mod\)最多进行\(logx\)次,一共只会进行\(nlogn\)次取模,那么就算对所有的数取模,这个时间复杂度都是可以接受的
那么我们对于每个区间记录一个最大值,如果取模的数大于最大值,就不管,如果小于最大值,就暴力取模
代码
struct BIT {
struct node {
int l, r;
long long sum;
int mx;
} tr[N << 2];
void push_up(int p) {
tr[p].mx = max(tr[p << 1].mx, tr[p << 1 | 1].mx);
tr[p].sum = tr[p << 1].sum + tr[p << 1 | 1].sum;
}
void build(int p, int l, int r, vector<int>& a) {
tr[p].l = l, tr[p].r = r;
if (l == r) {
tr[p].sum = tr[p].mx = a[l];
return;
}
int mid = l + r >> 1;
build(p << 1, l, mid, a);
build(p << 1 | 1, mid + 1, r, a);
push_up(p);
}
int ask(int p, int l, int r) {
if (tr[p].l >= l && tr[p].r <= r)
return tr[p].sum;
int mid = tr[p].l + tr[p].r >> 1;
if (mid >= r)
return ask(p << 1, l, r);
if (mid < l)
return ask(p << 1 | 1, l, r);
return ask(p << 1, l, r) + ask(p << 1 | 1, l, r);
}
void change(int p, int k, int v) {
if (tr[p].l == k && tr[p].r == k) {
tr[p].sum = tr[p].mx = v;
return;
}
int mid = tr[p].l + tr[p].r >> 1;
if (mid >= k)
change(p << 1, k, v);
else if (mid < k)
change(p << 1 | 1, k, v);
push_up(p);
}
void changemod(int p, int l, int r, int mod) {
if (tr[p].l == tr[p].r) {
tr[p].sum = tr[p].mx = tr[p].sum % mod;
return;
}
int mid = tr[p].l + tr[p].r >> 1;
if (mid >= l && tr[p << 1].mx >= mod)
changemod(p << 1, l, r, mod);
if (mid < r && tr[p << 1 | 1].mx >= mod)
changemod(p << 1 | 1, l, r, mod);
push_up(p);
}
} bit;
int main() {
ios::sync_with_stdio(0);
cin.tie(0), cout.tie(0);
int n, m;
cin >> n >> m;
vector<int> a(n + 1);
for (int i = 1; i <= n; ++i)
cin >> a[i];
bit.build(1, 1, n, a);
while (m--) {
int op;
cin >> op;
if (op == 1) {
int l, r;
cin >> l >> r;
cout << bit.ask(1, l, r) << '\n';
} else if (op == 2) {
int l, r, mod;
cin >> l >> r >> mod;
bit.changemod(1, l, r, mod);
} else {
int k, v;
cin >> k >> v;
bit.change(1, k, v);
}
}
return 0;
}