线段树选做
[NOIP2012]借教室
简单的区间加(减)与区间最小值查询:每个订单先查询区间最小值是否足够,够的话再把整个区间减去需求量。
#include <bits/stdc++.h>
using namespace std;
using i32 = int32_t;
using i64 = long long;
using ldb = long double;
//#define int i64
using vi = vector<int>;
using pii = pair<int, int>;
const int mod = 1e9 + 7;  // not referenced by this program (template leftover)

// Segment-tree node supporting range add and range-minimum query.
//   [l, r] : inclusive index range covered by this node
//   value  : minimum over [l, r], already including every add applied here
//   add    : lazy add that still has to be pushed to the children
struct Node {
    int l, r, value, add;
    Node *left, *right;
    Node(int l, int r, int value, Node *left, Node *right)
        : l(l), r(r), value(value), add(0), left(left), right(right) {}
} *root;

// Builds the tree over v[l..r] (inclusive) and returns its root.
Node *build(int l, int r, const vector<int> &v) {
    if (l != r) {
        const int mid = (l + r) / 2;
        Node *lhs = build(l, mid, v);
        Node *rhs = build(mid + 1, r, v);
        return new Node(l, r, min(lhs->value, rhs->value), lhs, rhs);
    }
    return new Node(l, r, v[l], nullptr, nullptr);
}

// Adds v to the whole range of cur, recording it as a pending tag for the children.
void mark(int v, Node *cur) {
    cur->value += v;
    cur->add += v;
}

// Hands the pending add of cur down to both children.
void pushdown(Node *cur) {
    if (cur->add != 0) {
        const int pending = cur->add;
        mark(pending, cur->left);
        mark(pending, cur->right);
        cur->add = 0;
    }
}

// Adds v to every position of [l, r] inside cur's range.
void modify(int l, int r, int v, Node *cur) {
    const bool disjoint = r < cur->l || l > cur->r;
    if (disjoint) return;
    const bool covered = cur->l >= l && r >= cur->r;
    if (covered) {
        mark(v, cur);
        return;
    }
    pushdown(cur);
    const int mid = (cur->l + cur->r) / 2;
    if (l <= mid) modify(l, r, v, cur->left);
    if (mid < r) modify(l, r, v, cur->right);
    cur->value = min(cur->left->value, cur->right->value);
}

// Minimum over [l, r] inside cur's range; INT_MAX when the ranges do not meet.
int query(int l, int r, Node *cur) {
    if (r < cur->l || l > cur->r) return INT_MAX;
    if (cur->l >= l && cur->r <= r) return cur->value;
    pushdown(cur);
    const int mid = (cur->l + cur->r) / 2;
    const int fromLeft = (l <= mid) ? query(l, r, cur->left) : INT_MAX;
    const int fromRight = (mid < r) ? query(l, r, cur->right) : INT_MAX;
    return min(fromLeft, fromRight);
}
// Reads n daily capacities and m orders; each order reads an amount v and a day
// range [l, r] and is applied in input order. Prints 0 if every order fits;
// otherwise prints -1 followed by the 1-based index of the first order that does not.
i32 main() {
    ios::sync_with_stdio(false), cin.tie(nullptr);
    int n, m;
    cin >> n >> m;
    vi a(n + 1);  // 1-indexed; a[0] is unused
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    root = build(1, n, a);
    for (int i = 1, l, r, v; i <= m; i++) {
        cin >> v >> l >> r;
        // The order fits only if every day in [l, r] still has at least v left.
        if (query(l, r, root) < v) {
            cout << "-1\n" << i << "\n";
            return 0;
        }
        modify(l, r, -v, root);
    }
    cout << "0\n";
    return 0;
}
数据结构
https://ac.nowcoder.com/acm/problem/19246
维护的信息复杂了一点:要同时维护区间和与区间平方和,并支持区间乘、区间加两种懒标记(标记含义为 \(x \to x\cdot mul + add\),下传时先乘后加)。
#include <bits/stdc++.h>
using namespace std;
using i32 = int32_t;
using i64 = long long;
using ldb = long double;
#define int i64
using vi = vector<int>;
// Segment-tree node for range multiply / range add with range-sum and
// range-sum-of-squares queries (int is i64 in this program via `#define int i64`).
//   [l, r]   : inclusive index range covered by this node
//   value    : sum of the elements in [l, r]
//   square   : sum of the squares of the elements in [l, r]
//   mul, add : pending lazy tags; every element x below this node is still to
//              become x * mul + add
struct Node {
int l, r, value, square, add, mul;
Node *left, *right;
Node(int l, int r, int value, int square, Node *left, Node *right)
: l(l), r(r), value(value), square(square), left(left), right(right) {
add = 0, mul = 1;
}
} *root;
// Builds the tree over v[l..r] (inclusive) and returns its root.
Node *build(int l, int r, const vector<int> &v) {
if (l == r) {
return new Node(l, r, v[l], v[l] * v[l], nullptr, nullptr);
}
int mid = (l + r) / 2;
Node *left = build(l, mid, v);
Node *right = build(mid + 1, r, v);
return new Node(l, r, left->value + right->value, left->square + right->square, left, right);
}
// Multiplies every element of cur's range by v.
// Composing tags: (x * mul + add) * v = x * (mul * v) + add * v, so both tags scale;
// each square scales by v * v.
void markMul(int v, Node *cur) {
cur->mul *= v;
cur->add *= v;
cur->value *= v;
cur->square *= v * v;
return;
}
// Adds v to every element of cur's range.
// sum (x + v)^2 = sum x^2 + 2 * v * sum x + len * v^2, which needs the OLD sum,
// so `square` must be updated before `value`.
void markAdd(int v, Node *cur) {
cur->square += (cur->r - cur->l + 1) * v * v + 2 * v * cur->value;
cur->value += v * (cur->r - cur->l + 1);
cur->add += v;
return;
}
// Pushes pending tags to the children: multiply first, then add, matching the
// x * mul + add meaning of the tags.
void pushdown(Node *cur) {
if (cur->mul != 1) {
markMul(cur->mul, cur->left);
markMul(cur->mul, cur->right);
cur->mul = 1;
}
if (cur->add != 0) {
markAdd(cur->add, cur->left);
markAdd(cur->add, cur->right);
cur->add = 0;
}
return;
}
// Sum of the elements in [l, r] that lie inside cur's range.
int calcValue(int l, int r, Node *cur) {
if (l > cur->r or r < cur->l) return 0;
if (l <= cur->l and cur->r <= r) return cur->value;
pushdown(cur);
int mid = (cur->l + cur->r) / 2, sum = 0;
if (l <= mid) sum += calcValue(l, r, cur->left);
if (r > mid) sum += calcValue(l, r, cur->right);
return sum;
}
// Sum of squares of the elements in [l, r] that lie inside cur's range.
int calcSquare(int l, int r, Node *cur) {
if (l > cur->r or r < cur->l) return 0;
if (l <= cur->l and cur->r <= r) return cur->square;
pushdown(cur);
int mid = (cur->l + cur->r) / 2, sum = 0;
if (l <= mid) sum += calcSquare(l, r, cur->left);
if (r > mid) sum += calcSquare(l, r, cur->right);
return sum;
}
// Multiplies every element of [l, r] by v.
void modifyMul(int l, int r, int v, Node *cur) {
if (l > cur->r or r < cur->l) return;
if (l <= cur->l and cur->r <= r) {
markMul(v, cur);
return;
}
pushdown(cur);
int mid = (cur->l + cur->r) / 2;
if (l <= mid) modifyMul(l, r, v, cur->left);
if (r > mid) modifyMul(l, r, v, cur->right);
cur->value = cur->left->value + cur->right->value;
cur->square = cur->left->square + cur->right->square;
return;
}
// Adds v to every element of [l, r].
void modifyAdd(int l, int r, int v, Node *cur) {
if (l > cur->r or r < cur->l) return;
if (l <= cur->l and cur->r <= r) {
markAdd(v, cur);
return;
}
pushdown(cur);
int mid = (cur->l + cur->r) / 2;
if (l <= mid) modifyAdd(l, r, v, cur->left);
if (r > mid) modifyAdd(l, r, v, cur->right);
cur->value = cur->left->value + cur->right->value;
cur->square = cur->left->square + cur->right->square;
return;
}
// Reads the array, then answers m operations:
//   1 l r    -> print the sum of a[l..r]
//   2 l r    -> print the sum of squares of a[l..r]
//   3 l r v  -> multiply a[l..r] by v
//   other: l r v -> add v to a[l..r]
i32 main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, m;
    cin >> n >> m;
    vector<int> a(n + 1);
    for (int idx = 1; idx <= n; ++idx) cin >> a[idx];
    root = build(1, n, a);
    while (m--) {
        int opt, l, r;
        cin >> opt >> l >> r;
        switch (opt) {
        case 1:
            cout << calcValue(l, r, root) << "\n";
            break;
        case 2:
            cout << calcSquare(l, r, root) << "\n";
            break;
        default: {
            int v;
            cin >> v;
            if (opt == 3) modifyMul(l, r, v, root);
            else modifyAdd(l, r, v, root);
            break;
        }
        }
    }
    return 0;
}
线段树
https://ac.nowcoder.com/acm/problem/212880
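两两乘积和不好直接维护,可以用区间和、区间平方和导出:\(\sum_{l\le i<j\le r} a_ia_j = \frac{\left(\sum_{i=l}^{r} a_i\right)^2 - \sum_{i=l}^{r} a_i^2}{2}\),模意义下除以\(2\)即乘以\(2\)的逆元。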
#include <bits/stdc++.h>
using namespace std;
using i32 = int32_t;
using i64 = long long;
using ldb = long double;
#define int i64
int mod;  // modulus of the current test case; read in solve()

// Binary exponentiation: returns x^y reduced modulo `mod`.
int power(int x, int y) {
    int result = 1;
    int base = x;
    for (int e = y; e != 0; e /= 2) {
        if (e & 1) result = result * base % mod;
        base = base * base % mod;
    }
    return result;
}

// Modular inverse by Fermat's little theorem (x^(mod-2)); only meaningful when
// `mod` is prime and x is not a multiple of it.
int inv(int x) { return power(x, mod - 2); }
using vi = vector<int>;
// Segment-tree node for range multiply / range add with range-sum and
// range-sum-of-squares queries, all reduced modulo the global `mod`
// (int is i64 in this program via `#define int i64`).
//   [l, r]   : inclusive index range covered by this node
//   value    : sum of the elements in [l, r] (mod `mod`)
//   square   : sum of the squares of the elements in [l, r] (mod `mod`)
//   mul, add : pending lazy tags; every element x below this node is still to
//              become x * mul + add
// NOTE(review): results stay in [0, mod) only for non-negative a[i] and v, and
// products such as len * v * v assume v is small enough not to overflow i64 —
// confirm against the problem's value ranges.
struct Node {
int l, r, value, square, add, mul;
Node *left, *right;
Node(int l, int r, int value, int square, Node *left, Node *right)
: l(l), r(r), value(value), square(square), left(left), right(right) {
add = 0, mul = 1;
}
} *root;
// Builds the tree over v[l..r] (inclusive), reducing leaves modulo `mod`.
Node *build(int l, int r, const vector<int> &v) {
if (l == r) {
return new Node(l, r, v[l] % mod, v[l] * v[l] % mod, nullptr, nullptr);
}
int mid = (l + r) / 2;
Node *left = build(l, mid, v);
Node *right = build(mid + 1, r, v);
return new Node(l, r, (left->value + right->value) % mod, (left->square + right->square) % mod, left, right);
}
// Multiplies every element of cur's range by v: both tags scale by v and each
// square scales by v * v.
void markMul(int v, Node *cur) {
(cur->mul *= v) %= mod;
(cur->add *= v) %= mod;
(cur->value *= v) %= mod;
(cur->square *= v * v % mod) %= mod;
return;
}
// Adds v to every element of cur's range.
// sum (x + v)^2 = sum x^2 + 2 * v * sum x + len * v^2 needs the OLD sum, so
// `square` must be updated before `value`.
void markAdd(int v, Node *cur) {
(cur->square += (cur->r - cur->l + 1) * v % mod * v % mod + 2 * v % mod * cur->value % mod) %= mod;
(cur->value += v * (cur->r - cur->l + 1) % mod) %= mod;
(cur->add += v) %= mod;
return;
}
// Pushes pending tags to the children: multiply first, then add, matching the
// x * mul + add meaning of the tags.
void pushdown(Node *cur) {
if (cur->mul != 1) {
markMul(cur->mul, cur->left);
markMul(cur->mul, cur->right);
cur->mul = 1;
}
if (cur->add != 0) {
markAdd(cur->add, cur->left);
markAdd(cur->add, cur->right);
cur->add = 0;
}
return;
}
// Sum (mod `mod`) of the elements in [l, r] that lie inside cur's range.
int calcValue(int l, int r, Node *cur) {
if (l > cur->r or r < cur->l) return 0;
if (l <= cur->l and cur->r <= r) return cur->value;
pushdown(cur);
int mid = (cur->l + cur->r) / 2, sum = 0;
if (l <= mid) (sum += calcValue(l, r, cur->left)) %= mod;
if (r > mid) (sum += calcValue(l, r, cur->right)) %= mod;
return sum;
}
// Sum of squares (mod `mod`) of the elements in [l, r] inside cur's range.
int calcSquare(int l, int r, Node *cur) {
if (l > cur->r or r < cur->l) return 0;
if (l <= cur->l and cur->r <= r) return cur->square;
pushdown(cur);
int mid = (cur->l + cur->r) / 2, sum = 0;
if (l <= mid) (sum += calcSquare(l, r, cur->left)) %= mod;
if (r > mid) (sum += calcSquare(l, r, cur->right)) %= mod;
return sum;
}
// Multiplies every element of [l, r] by v (mod `mod`).
void modifyMul(int l, int r, int v, Node *cur) {
if (l > cur->r or r < cur->l) return;
if (l <= cur->l and cur->r <= r) {
markMul(v, cur);
return;
}
pushdown(cur);
int mid = (cur->l + cur->r) / 2;
if (l <= mid) modifyMul(l, r, v, cur->left);
if (r > mid) modifyMul(l, r, v, cur->right);
cur->value = (cur->left->value + cur->right->value) % mod;
cur->square = (cur->left->square + cur->right->square) % mod;
return;
}
// Adds v to every element of [l, r] (mod `mod`).
void modifyAdd(int l, int r, int v, Node *cur) {
if (l > cur->r or r < cur->l) return;
if (l <= cur->l and cur->r <= r) {
markAdd(v, cur);
return;
}
pushdown(cur);
int mid = (cur->l + cur->r) / 2;
if (l <= mid) modifyAdd(l, r, v, cur->left);
if (r > mid) modifyAdd(l, r, v, cur->right);
cur->value = (cur->left->value + cur->right->value) % mod;
cur->square = (cur->left->square + cur->right->square) % mod;
return;
}
void solve(){
int n, m;
cin >> n >> m >> mod;
vector<int> a(n + 1);
for (int i = 1; i <= n; i++)
cin >> a[i];
root = build(1, n, a);
for (int opt, l, r, v; m; m--) {
cin >> opt >> l >> r;
if (opt == 1) {
cin >> v;
modifyAdd(l, r, v, root);
} else if (opt == 2) {
cin >> v;
modifyMul(l, r, v, root);
} else {
int x = calcValue(l, r, root), y = calcSquare(l, r, root);
cout << (x * x % mod - y % mod + mod) % mod * inv(2) % mod << "\n";
}
}
}
i32 main() {
ios::sync_with_stdio(false), cin.tie(nullptr);
int TC;
for( cin >> TC; TC; TC --)
solve();
return 0;
}
仓鼠的鸡蛋
https://ac.nowcoder.com/acm/problem/226170
在线段树上二分查找第一个剩余容量不小于\(x\)的位置(节点维护子树内的最大剩余容量),并返回其下标;找不到时返回\(-1\)。
#include <bits/stdc++.h>
using namespace std;
using i32 = int32_t;
using i64 = long long;
const int inf = INT_MAX / 2;
using vi = vector<int>;
using pii = pair<int, int>;
// Segment-tree node over basket indices [l, r].
//   val : at a leaf, the basket's remaining capacity, or -inf once it may take no
//         more eggs; at an inner node, the maximum over its children
//   cnt : at a leaf, how many more eggs the basket may accept (0 on inner nodes)
struct Node {
    int l, r, val, cnt;
    Node *left, *right;
    Node(int l, int r, int val, int cnt, Node *left = nullptr, Node *right = nullptr)
        : l(l), r(r), val(val), cnt(cnt), left(left), right(right) {}
} *root;
// Builds baskets [l, r], each starting with capacity m and room for k eggs.
Node *build(int l, int r, int m, int k) {
    if (l == r) return new Node(l, r, m, k);
    int mid = (l + r) / 2;
    Node *left = build(l, mid, m, k);
    Node *right = build(mid + 1, r, m, k);
    return new Node(l, r, m, 0, left, right);
}
// Uses one egg slot of basket p and changes its capacity by delta (the caller
// passes -x). A basket with no slots left is disabled by setting val = -inf.
void modify(int p, int delta, Node *cur) {
    if (p < cur->l or cur->r < p) return;
    if (cur->l == cur->r) {
        cur->val += delta;
        cur->cnt--;
        if (cur->cnt == 0) cur->val = -inf;
        return;
    }
    int mid = (cur->l + cur->r) / 2;
    if (p <= mid) modify(p, delta, cur->left);
    else modify(p, delta, cur->right);
    cur->val = max(cur->left->val, cur->right->val);
}
// Returns the leftmost basket whose remaining capacity is at least x, or -1.
// The subtree maximum is checked before the leaf test. Otherwise a root that is
// itself a leaf (n == 1) would be returned even when x does not fit.
int query(int x, Node *cur) {
    if (cur->val < x) return -1;
    if (cur->l == cur->r) return cur->l;
    if (cur->left->val >= x) return query(x, cur->left);
    return query(x, cur->right);
}
// One test case: n baskets, each starting with capacity m and room for k eggs.
// For each of the n incoming weights x, prints the leftmost basket that can
// still hold it (or -1) and, if one exists, places it there.
void solve() {
    int n, m, k;
    cin >> n >> m >> k;
    root = build(1, n, m, k);
    for (int x, ans, i = 1; i <= n; i++) {
        cin >> x;
        ans = query(x, root);
        cout << ans << "\n";
        if (ans != -1) modify(ans, -x, root);
    }
    return;
}
// Reads the number of test cases and runs solve() for each.
int main() {
    ios::sync_with_stdio(false), cin.tie(nullptr);
    int TC;
    for (cin >> TC; TC; TC--)
        solve();
    return 0;
}
P5490 【模板】扫描线
https://www.luogu.com.cn/problem/P5490
扫描线模板题
// luogu P5490
#include<bits/stdc++.h>
using namespace std;
using i32 = int32_t;
using i64 = long long;
#define int i64
const int inf = INT_MAX / 2;
using seg = array<int, 4>; // {x, y1, y2, v}
using vi = vector<int>;
// Segment-tree node over elementary y-intervals: leaf i stands for [raw[i], raw[i + 1]).
//   value : total y-length covered within this node's range
//   cnt   : number of times this node's whole range is covered by active edges
struct Node {
    int l, r, value, cnt;
    Node *left, *right;
    Node(int l, int r, int value, int cnt, Node *left, Node *right)
        : l(l), r(r), value(value), cnt(cnt), left(left), right(right) {}
} *root;
vi raw; // sorted, de-duplicated original y-coordinates
// Maps an original coordinate x to its index in raw (the compressed coordinate),
// or -1 if x is not present. Also safe for x greater than every element of raw.
int val(int x) {
    auto it = lower_bound(raw.begin(), raw.end(), x);
    if (it == raw.end() || *it != x) return -1;
    return it - raw.begin();
}
// Builds an all-uncovered tree over elementary intervals [l, r].
Node *build(int l, int r) {
    if (l == r) return new Node(l, r, 0, 0, nullptr, nullptr);
    int mid = (l + r) / 2;
    auto left = build(l, mid), right = build(mid + 1, r);
    return new Node(l, r, 0, 0, left, right);
}
// Recomputes cur->value. A fully covered node covers raw[r + 1] - raw[l]; otherwise
// a leaf covers nothing and an inner node sums its children. Reading raw[cur->r + 1]
// relies on callers never covering the last index of raw (modify is called with
// r = val(y2) - 1).
void maintain(Node *cur) {
    if (cur->cnt > 0) cur->value = raw[cur->r + 1] - raw[cur->l];
    else if (cur->left == nullptr) cur->value = 0;
    else cur->value = cur->left->value + cur->right->value;
}
// Adds v (+1 opens, -1 closes an edge) to the cover count of elementary intervals
// [l, r]. No pushdown is needed: cover counts stay on the nodes they were added to.
void modify(int l, int r, int v, Node *cur) {
    if (l > cur->r or r < cur->l) return;
    if (l <= cur->l and cur->r <= r) {
        cur->cnt += v;
        maintain(cur);
        return;
    }
    int mid = (cur->l + cur->r) / 2;
    if (l <= mid) modify(l, r, v, cur->left);
    if (r > mid) modify(l, r, v, cur->right);
    maintain(cur);
}
// Reads n axis-aligned rectangles (x1 y1 x2 y2) and prints the area of their union.
// A vertical line sweeps x; between consecutive edges the covered y-length
// (root->value) times the x-distance is added to the area.
i32 main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n;
    cin >> n;
    vector<seg> p;
    for (int i = 0; i < n; ++i) {
        int xa, ya, xb, yb;
        cin >> xa >> ya >> xb >> yb;
        p.push_back(seg{xa, ya, yb, 1});   // left edge: opens [ya, yb)
        p.push_back(seg{xb, ya, yb, -1});  // right edge: closes [ya, yb)
        raw.push_back(ya);
        raw.push_back(yb);
    }
    sort(p.begin(), p.end());
    sort(raw.begin(), raw.end());
    raw.erase(unique(raw.begin(), raw.end()), raw.end());
    const int tot = raw.size() - 1;
    root = build(0, tot);
    // Edge [y1, y2) maps to elementary intervals val(y1) .. val(y2) - 1.
    auto applyEdge = [](const seg &e) { modify(val(e[1]), val(e[2]) - 1, e[3], root); };
    applyEdge(p[0]);
    int res = 0;
    for (size_t i = 1; i < p.size(); ++i) {
        res += (p[i][0] - p[i - 1][0]) * root->value;
        applyEdge(p[i]);
    }
    cout << res;
    return 0;
}
P4588 [TJOI2018] 数学计算
做法很多,这里提供一个用线段树解决的思路。开一个大小为\(Q\)的数组,初始全为\(1\):第\(i\)次操作若为操作\(1\),就把第\(i\)位修改为\(m\);若为操作\(2\),就把第\(pos\)位修改为\(1\)。每次操作后输出整个数组的乘积(模\(M\))就好。
#include<bits/stdc++.h>
using namespace std;
using i32 = int32_t;
using i64 = long long;
#define int i64
const int inf = INT_MAX / 2;
int mod;  // modulus of the current test case; read in solve()
// Segment-tree node storing the product (mod `mod`) of the leaves in [l, r].
struct Node {
    int l, r, val;
    Node *left, *right;
    Node(int l, int r, int val, Node *left = nullptr, Node *right = nullptr)
        : l(l), r(r), val(val), left(left), right(right) {}
} *root;
// Builds a tree over [l, r] in which every leaf, and hence every product, is 1.
Node *build(int l, int r) {
    if (l == r) return new Node(l, r, 1);
    const int mid = (l + r) / 2;
    Node *lhs = build(l, mid);
    Node *rhs = build(mid + 1, r);
    return new Node(l, r, 1, lhs, rhs);
}
// Sets leaf p to val % mod and refreshes the products on the path to the root.
void modify(int p, int val, Node *cur) {
    if (p < cur->l || cur->r < p) return;
    if (cur->l == cur->r) {
        cur->val = val % mod;
        return;
    }
    const int mid = (cur->l + cur->r) / 2;
    modify(p, val, p <= mid ? cur->left : cur->right);
    cur->val = cur->left->val * cur->right->val % mod;
}
using vi = vector<int>;
// Debug helper: prints "l r val" for every node to stderr in pre-order
// (node, then left subtree, then right subtree).
void debug(Node *cur) {
    vector<Node *> pending{cur};
    while (!pending.empty()) {
        Node *node = pending.back();
        pending.pop_back();
        cerr << node->l << " " << node->r << " " << node->val << "\n";
        if (node->left != nullptr) {
            pending.push_back(node->right);
            pending.push_back(node->left);
        }
    }
}
// One test case: reads the operation count n and the modulus. Leaf i holds the
// factor introduced by operation i (initially 1):
//   1 x -> leaf i becomes x
//   2 x -> leaf x is reset to 1
// After every operation the product of all leaves (mod `mod`) is printed.
void solve() {
    int n;
    cin >> n >> mod;
    root = build(1, n);
    for (int i = 1; i <= n; ++i) {
        int opt, x;
        cin >> opt >> x;
        const int pos = (opt == 1) ? i : x;
        const int factor = (opt == 1) ? x : 1;
        modify(pos, factor, root);
        cout << root->val << "\n";
    }
}
// Reads the number of test cases and runs solve() for each.
i32 main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int TC;
    cin >> TC;
    while (TC--) solve();
    return 0;
}
D. Traffic Jams in the Land
https://codeforces.com/problemset/problem/498/D
一个很有用的性质是\(lcm (2,3,4,5,6) = 60\),因此对于到达时间我们模\(60\)就好了。
然后线段树节点存储信息\(\{l,r,val[60]\}\),其中\(val[i]\)表示在时刻\(t \equiv i \pmod{60}\)到达\(l\)并一直走完\(r\)所需要的时间。合并两个子节点时,对每个\(i\)先查左儿子、再用偏移后的时刻查右儿子,因此单次修改的复杂度为\(O(60\log n)\);查询只需沿途查表,单次为\(O(\log n)\)。
#include <bits/stdc++.h>
using namespace std;
using vi = vector<int>;
// Segment-tree node over road segments [l, r].
//   val[t] : time needed to traverse segments l..r when arriving at segment l at a
//            time congruent to t modulo 60 (periods are taken to divide 60,
//            since lcm(2, 3, 4, 5, 6) = 60)
struct Node {
    int l, r;
    vector<int> val;
    Node *left, *right;
    // Leaf for one segment with period x: crossing costs 2 when the arrival time
    // is a multiple of x, and 1 otherwise.
    Node(int l, int r, int x) : l(l), r(r), val(60), left(nullptr), right(nullptr) {
        for (int t = 0; t < 60; ++t) val[t] = (t % x == 0) ? 2 : 1;
    }
    // Inner node; its table is filled in by maintain().
    Node(int l, int r, Node *left, Node *right)
        : l(l), r(r), val(60), left(left), right(right) {}
} *root;
// Recomputes cur's table: cross the left half first, then the right half
// starting at the shifted arrival time.
void maintain(Node *cur) {
    const vector<int> &lhs = cur->left->val;
    const vector<int> &rhs = cur->right->val;
    for (int t = 0; t < 60; ++t) {
        const int spent = lhs[t];
        cur->val[t] = spent + rhs[(t + spent) % 60];
    }
}
// Builds the tree over segments l..r using the periods v[l..r].
Node *build(int l, int r, const vi &v) {
    if (l == r) return new Node(l, r, v[l]);
    const int mid = (l + r) / 2;
    Node *lhs = build(l, mid, v);
    Node *rhs = build(mid + 1, r, v);
    Node *node = new Node(l, r, lhs, rhs);
    maintain(node);
    return node;
}
// Sets the period of segment p to `delta` and rebuilds the tables on the path to the root.
void modify(int p, int delta, Node *cur) {
    if (p < cur->l || cur->r < p) return;
    if (cur->l == cur->r) {
        for (int t = 0; t < 60; ++t) cur->val[t] = (t % delta == 0) ? 2 : 1;
        return;
    }
    const int mid = (cur->l + cur->r) / 2;
    modify(p, delta, p <= mid ? cur->left : cur->right);
    maintain(cur);
}
// Time to traverse the segments of [l, r] inside cur's range when arriving at
// time `start` (already reduced mod 60); segments are walked left to right.
int query(int l, int r, int start, Node *cur) {
    if (r < cur->l || l > cur->r) return 0;
    if (l <= cur->l && cur->r <= r) return cur->val[start];
    const int mid = (cur->l + cur->r) / 2;
    int elapsed = 0;
    if (l <= mid) elapsed += query(l, r, (start + elapsed) % 60, cur->left);
    if (mid < r) elapsed += query(l, r, (start + elapsed) % 60, cur->right);
    return elapsed;
}
// Reads n segment periods, then m commands:
//   'C' x d   -> set the period of segment x to d
//   other x y -> print the time to go from position x to position y starting
//                at time 0 (i.e. traverse segments x .. y-1)
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n;
    cin >> n;
    vi a(n + 1);
    for (int i = 1; i <= n; ++i) cin >> a[i];
    root = build(1, n, a);
    int m;
    cin >> m;
    while (m--) {
        char opt;
        int x, y;
        cin >> opt;
        cin >> x >> y;
        if (opt == 'C') modify(x, y, root);
        else cout << query(x, y - 1, 0, root) << "\n";
    }
    return 0;
}
F - Vacation Query
要用线段树实现区间翻转,以及查询区间内最长的连续\(1\)子段。
可以在维护最长连续\(1\)(及其前缀、后缀长度)的同时,也维护最长连续\(0\)的对应信息;需要翻转时直接交换两套信息就好了。
#include<bits/stdc++.h>
using namespace std;
using i32 = int32_t;
using i64 = long long;
#define int i64
const int inf = INT_MAX / 2;
using vi = vector<int>;
// Summary of a 0/1 segment, kept for both bit values b in {0, 1}:
//   pre[b] : longest prefix consisting only of b
//   suf[b] : longest suffix consisting only of b
//   val[b] : longest run of b anywhere in the segment
//   len[b] : the segment length if it is entirely b, otherwise a large negative
//            number, so len[b] + (something) never wins a max
// The default-constructed Info describes the empty segment and acts as the
// identity of operator+.
struct Info {
    array<int, 2> pre, suf, val, len;
    Info() : pre{0, 0}, suf{0, 0}, val{0, 0}, len{0, 0} {}
    // A single element holding bit x.
    Info(int x) {
        const int other = x ^ 1;
        pre[x] = suf[x] = val[x] = len[x] = 1;
        pre[other] = suf[other] = val[other] = 0;
        len[other] = -inf;
    }
    // Flipping every bit just exchanges the 0- and 1-summaries.
    void swap() {
        for (auto *field : {&pre, &suf, &val, &len}) std::swap((*field)[0], (*field)[1]);
    }
    // Concatenation: x is the left part, y the right part.
    friend Info operator+(Info x, Info y) {
        Info res;
        for (int b : {0, 1}) {
            res.len[b] = x.len[b] + y.len[b];
            res.pre[b] = std::max(x.pre[b], x.len[b] + y.pre[b]);
            res.suf[b] = std::max(y.suf[b], y.len[b] + x.suf[b]);
            res.val[b] = std::max({x.val[b], y.val[b], x.suf[b] + y.pre[b]});
        }
        return res;
    }
};
struct Node {
i32 l, r, tag;
Info info;
Node *left, *right;
Node() {};
Node(int p, int v) : info(v) {
l = r = p, tag = 0;
left = right = nullptr;
}
Node(int l, int r, Node *left, Node *right)
: l(l), r(r), left(left), right(right) {
tag = 0;
info = left->info + right->info;
}
};
void maintain(Node *cur) {
if (cur->left == nullptr) return;
cur->info = cur->left->info + cur->right->info;
}
void mark(Node *cur) {
cur->tag ^= 1;
cur->info.swap();
return;
}
void pushdown(Node *cur) {
if (cur->tag == 0) return;
if (cur->left != nullptr) {
mark(cur->left);
mark(cur->right);
}
cur->tag = 0;
return;
}
Node *build(i32 l, i32 r, const string &s) {
if (l == r) return new Node(l, s[l - 1] - '0');
i32 mid = (l + r) / 2;
auto left = build(l, mid, s), right = build(mid + 1, r, s);
return new Node(l, r, left, right);
}
Info query(i32 l, i32 r, Node *cur) {
if (cur == nullptr) return Info();
if (l > cur->r or r < cur->l) return Info();
if (l <= cur->l and cur->r <= r) return cur->info;
pushdown(cur);
return query(l, r, cur->left) + query(l, r, cur->right);
}
void modify(i32 l, i32 r, Node *cur) {
if (cur == nullptr) return;
if (l > cur->r or r < cur->l) return;
if (l <= cur->l and cur->r <= r) {
mark(cur);
return;
}
pushdown(cur);
modify(l, r, cur->left);
modify(l, r, cur->right);
maintain(cur);
}
// Reads a 0/1 string of length n and q queries:
//   1 l r     -> flip s[l..r]
//   other l r -> print the longest run of '1' inside s[l..r]
i32 main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    i32 n, q;
    string s;
    cin >> n >> q >> s;
    Node *root = build(1, n, s);
    while (q--) {
        i32 op, l, r;
        cin >> op >> l >> r;
        if (op != 1) cout << query(l, r, root).val[1] << "\n";
        else modify(l, r, root);
    }
    return 0;
}