【数据结构】线段树-区间修改
新模板,使用了模板和默认值、幺元等,分离node中的数据维护部分。当维护的数据与l,r有关时,节点内需要存l,r,可能导致空间开销比较大。当query的merge的成本太高时,不能用下面这个模板,必须用旧模板。旧模板中query的流程和每个节点中的数据相关,也就是说query是可以理解sum[u]和sum[ls], sum[rs]的,方便进行短路运算以及与l,r有关的操作(如CF-1902D)以及线段树二分等需要控制短路计算左右子树的问题。
新模板适用于常规的区间和、最小值、最大值、区间反转、字符串哈希等,特征为:query 中单次合并的代价是 O(1) 或 O(log n)。使用新模板时,把数据的操作全部交给 node 自行计算并合并,然后即可在 query 中取出;线段树本身不理解 node 中存储的数据。
// Per-type constants consumed by segment_tree:
//   empty_value - value of an untouched element and the "no pending tag" marker
//   min_value   - starting value for a max aggregate (identity of max)
//   max_value   - starting value for a min aggregate (identity of min)
// NOTE(review): INF, LINF and ll are assumed to be defined by the including
// file before this point; they are not declared here.
template<typename>
struct default_value;

template<>
struct default_value<int> {
    static constexpr int empty_value = 0, min_value = -INF, max_value = INF;
};

template<>
struct default_value<ll> {
    static constexpr ll empty_value = 0LL, min_value = -LINF, max_value = LINF;
};
// Returns the smallest power of two that is >= 2 * n (1 when n <= 0).
// For a tree rooted at index 1 with children 2u / 2u+1 built over [1, n],
// every node index is strictly below this value, so it is a safe array length.
static constexpr int calc_segment_tree_size (int n) {
    int size = 1;
    for (const int target = n << 1; size < target; size <<= 1) {
    }
    return size;
}
struct segment_tree {
#define ls (u << 1)
#define rs (u << 1 | 1)
#define mid ((l + r) >> 1)
static const int MAXN = 2e5 + 10;
static constexpr int MAXN_SEGMENT_TREE_SIZE = calc_segment_tree_size (MAXN);
using value_type = ll;
static const value_type empty_value = default_value<value_type>::empty_value;
static const value_type min_value = default_value<value_type>::min_value;
static const value_type max_value = default_value<value_type>::max_value;
int n;
struct node_type {
int id;
int len;
value_type tag;
value_type sum;
value_type mi;
value_type mx;
node_type (int u = 0, int l = 0, int r = 0) {
id = u;
len = (r - l + 1);
tag = empty_value;
sum = empty_value;
mi = max_value;
mx = min_value;
}
node_type (int u, int l, int r, value_type value) {
id = u;
len = (r - l + 1);
tag = empty_value;
sum = value;
mi = value;
mx = value;
}
void merge (const node_type& left_node, const node_type& right_node) {
len = left_node.len + right_node.len;
sum = left_node.sum + right_node.sum;
mi = min (left_node.mi, right_node.mi);
mx = max (left_node.mx, right_node.mx);
}
// 加法
void add (value_type value) {
tag += value;
sum += 1LL * len * value;
mi += value;
mx += value;
}
// 设值
// void set (value_type value) {
// tag = value;
// sum = 1LL * len * value;
// mi = value;
// mx = value;
// }
} node[MAXN_SEGMENT_TREE_SIZE];
void push_up (int u, int l, int r) {
node[u].merge (node[ls], node[rs]);
}
void push_down (int u, int l, int r) {
value_type t = node[u].tag;
if (t != empty_value) {
// 区间加法
node[ls].add (t);
node[rs].add (t);
// 区间设值
// node[ls].set (t);
// node[rs].set (t);
// 区间设值的懒标记要用特殊标记
node[u].tag = empty_value;
}
}
void build (int u, int l, int r, value_type* a) {
if (l == r) {
if (a == nullptr) {
node[u] = node_type (u, l, r);
} else {
node[u] = node_type (u, l, r, a[l]);
}
return;
}
build (ls, l, mid, a);
build (rs, mid + 1, r, a);
push_up (u, l, r);
}
// 区间加法
void add (int u, int l, int r, int L, int R, value_type value) {
if (L > R || L > r || R < l) {
return;
}
if (L <= l && r <= R) {
node[u].add (value);
return;
}
push_down (u, l, r);
add (ls, l, mid, L, R, value);
add (rs, mid + 1, r, L, R, value);
push_up (u, l, r);
}
// 区间设值
// void add (int u, int l, int r, int L, int R, value_type value) {
// if (L > R || L > r || R < l) {
// return;
// }
// if (L <= l && r <= R) {
// node[u].set (value);
// return;
// }
// push_down (u, l, r);
// set (ls, l, mid, L, R, value);
// set (rs, mid + 1, r, L, R, value);
// push_up (u, l, r);
// }
node_type query (int u, int l, int r, int L, int R) {
if (L > R || L > r || R < l) {
return node_type();
}
if (L <= l && r <= R)
return node[u];
push_down (u, l, r);
node_type l_res = query (ls, l, mid, L, R);
node_type r_res = query (rs, mid + 1, r, L, R);
node_type res;
res.merge (l_res, r_res);
return res;
}
void Build (int n, value_type* a = nullptr) {
this->n = n;
build (1, 1, n, a);
}
// 区间加法
void Add (int L, int R, value_type value) {
add (1, 1, n, L, R, value);
}
// 区间设值
// void Set (int L, int R, value_type value) {
// set (1, 1, n, L, R, value);
// }
node_type Query (int L, int R) {
return query (1, 1, n, L, R);
}
value_type QueryMin (int L, int R) {
return Query (L, R).mi;
}
value_type QueryMax (int L, int R) {
return Query (L, R).mx;
}
value_type QuerySum (int L, int R) {
return Query (L, R).sum;
}
#undef ls
#undef rs
#undef mid
} st;
旧模板:写法比较熟悉,便于改动。
区间加法
// Lazy segment tree (range add; range min / max / sum). Arrays are indexed by
// node id with root 1 and children 2u / 2u+1; positions are [l, r] as passed in.
struct SegmentTree {
#define ls (u << 1)
#define rs (u << 1 | 1)
#define mid ((l + r) >> 1)
    static const int MAXN = 2e5 + 10;
    static const ll LINF = 1e18 + 10;
    ll tag[MAXN << 2];  // pending add not yet applied to node u's children
    ll mi[MAXN << 2];
    ll mx[MAXN << 2];
    ll sum[MAXN << 2];
    // Rebuild node u's aggregates from its children.
    void PushUp(int u) {
        mi[u] = min(mi[ls], mi[rs]);
        mx[u] = max(mx[ls], mx[rs]);
        sum[u] = sum[ls] + sum[rs];
    }
    // Apply node u's pending add to both children, then clear it.
    void PushDown(int u, int l, int r) {
        ll t = tag[u];
        if(t != 0) {
            tag[ls] += t, mi[ls] += t, mx[ls] += t;
            sum[ls] += 1LL * (mid - l + 1) * t;
            tag[rs] += t, mi[rs] += t, mx[rs] += t;
            sum[rs] += 1LL * (r - mid) * t;
            tag[u] = 0;
        }
    }
    // Build [l, r] with every position set to 0.
    void Build(int u, int l, int r) {
        if(l > r) return;  // empty range (n == 0): mid never shrinks otherwise
        tag[u] = 0;
        if(l == r) {
            mi[u] = mx[u] = sum[u] = 0;
            return;
        }
        Build(ls, l, mid);
        Build(rs, mid + 1, r);
        PushUp(u);
    }
    // Add v to every position of [L, R].
    void UpdateAdd(int u, int l, int r, int L, int R, ll v) {
        // Without this guard an empty [L, R] (L > R) can reach a leaf that is
        // not covered, push down into its phantom children and recurse forever.
        if(L > R || R < l || r < L) return;
        if(L <= l && r <= R) {
            tag[u] += v, mi[u] += v, mx[u] += v;
            sum[u] += 1LL * (r - l + 1) * v;
            return;
        }
        PushDown(u, l, r);
        if(L <= mid) UpdateAdd(ls, l, mid, L, R, v);
        if(R >= mid + 1) UpdateAdd(rs, mid + 1, r, L, R, v);
        PushUp(u);
    }
    // Minimum over [L, R]; LINF for an empty or disjoint range.
    ll QueryMin(int u, int l, int r, int L, int R) {
        if(L > R || R < l || r < L) return LINF;
        if(L <= l && r <= R)
            return mi[u];
        PushDown(u, l, r);
        ll res = LINF;
        if(L <= mid) res = min(res, QueryMin(ls, l, mid, L, R));
        if(R >= mid + 1) res = min(res, QueryMin(rs, mid + 1, r, L, R));
        return res;
    }
    // Maximum over [L, R]; -LINF for an empty or disjoint range.
    ll QueryMax(int u, int l, int r, int L, int R) {
        if(L > R || R < l || r < L) return -LINF;
        if(L <= l && r <= R)
            return mx[u];
        PushDown(u, l, r);
        ll res = -LINF;
        if(L <= mid) res = max(res, QueryMax(ls, l, mid, L, R));
        if(R >= mid + 1) res = max(res, QueryMax(rs, mid + 1, r, L, R));
        return res;
    }
    // Sum over [L, R]; 0 for an empty or disjoint range.
    ll QuerySum(int u, int l, int r, int L, int R) {
        if(L > R || R < l || r < L) return 0;
        if(L <= l && r <= R)
            return sum[u];
        PushDown(u, l, r);
        ll res = 0;
        if(L <= mid) res += QuerySum(ls, l, mid, L, R);
        if(R >= mid + 1) res += QuerySum(rs, mid + 1, r, L, R);
        return res;
    }
#undef ls
#undef rs
#undef mid
} st;
区间设值
// Lazy segment tree (range assign; range min / max / sum). Arrays are indexed
// by node id with root 1 and children 2u / 2u+1. tag[u] == LINF means "no
// pending assignment", so LINF itself cannot be assigned as a value.
struct SegmentTree {
#define ls (u << 1)
#define rs (u << 1 | 1)
#define mid ((l + r) >> 1)
    static const int MAXN = 2e5 + 10;
    static const ll LINF = 1e18 + 10;
    ll tag[MAXN << 2];  // pending assignment for node u's children, or LINF
    ll mi[MAXN << 2];
    ll mx[MAXN << 2];
    ll sum[MAXN << 2];
    // Rebuild node u's aggregates from its children.
    void PushUp(int u) {
        mi[u] = min(mi[ls], mi[rs]);
        mx[u] = max(mx[ls], mx[rs]);
        sum[u] = sum[ls] + sum[rs];
    }
    // Apply node u's pending assignment to both children, then clear it.
    void PushDown(int u, int l, int r) {
        ll t = tag[u];
        if(t != LINF) {
            tag[ls] = mi[ls] = mx[ls] = t;
            sum[ls] = 1LL * (mid - l + 1) * t;
            tag[rs] = mi[rs] = mx[rs] = t;
            sum[rs] = 1LL * (r - mid) * t;
            tag[u] = LINF;
        }
    }
    // Build [l, r] with every position set to 0 and no pending assignment.
    void Build(int u, int l, int r) {
        if(l > r) return;  // empty range (n == 0): mid never shrinks otherwise
        tag[u] = LINF;
        if(l == r) {
            mi[u] = mx[u] = sum[u] = 0;
            return;
        }
        Build(ls, l, mid);
        Build(rs, mid + 1, r);
        PushUp(u);
    }
    // Set every position of [L, R] to v.
    void UpdateSet(int u, int l, int r, int L, int R, ll v) {
        // Without this guard an empty [L, R] (L > R) can reach a leaf that is
        // not covered, push down into its phantom children and recurse forever.
        if(L > R || R < l || r < L) return;
        if(L <= l && r <= R) {
            tag[u] = mi[u] = mx[u] = v;
            sum[u] = 1LL * (r - l + 1) * v;
            return;
        }
        PushDown(u, l, r);
        if(L <= mid) UpdateSet(ls, l, mid, L, R, v);
        if(R >= mid + 1) UpdateSet(rs, mid + 1, r, L, R, v);
        PushUp(u);
    }
    // Minimum over [L, R]; LINF for an empty or disjoint range.
    ll QueryMin(int u, int l, int r, int L, int R) {
        if(L > R || R < l || r < L) return LINF;
        if(L <= l && r <= R)
            return mi[u];
        PushDown(u, l, r);
        ll res = LINF;
        if(L <= mid) res = min(res, QueryMin(ls, l, mid, L, R));
        if(R >= mid + 1) res = min(res, QueryMin(rs, mid + 1, r, L, R));
        return res;
    }
    // Maximum over [L, R]; -LINF for an empty or disjoint range.
    ll QueryMax(int u, int l, int r, int L, int R) {
        if(L > R || R < l || r < L) return -LINF;
        if(L <= l && r <= R)
            return mx[u];
        PushDown(u, l, r);
        ll res = -LINF;
        if(L <= mid) res = max(res, QueryMax(ls, l, mid, L, R));
        if(R >= mid + 1) res = max(res, QueryMax(rs, mid + 1, r, L, R));
        return res;
    }
    // Sum over [L, R]; 0 for an empty or disjoint range.
    ll QuerySum(int u, int l, int r, int L, int R) {
        if(L > R || R < l || r < L) return 0;
        if(L <= l && r <= R)
            return sum[u];
        PushDown(u, l, r);
        ll res = 0;
        if(L <= mid) res += QuerySum(ls, l, mid, L, R);
        if(R >= mid + 1) res += QuerySum(rs, mid + 1, r, L, R);
        return res;
    }
#undef ls
#undef rs
#undef mid
} st;
混合标记:
加法和乘法的混合标记
超省心版(注意:下面的实现只有区间加法标记,没有实现乘法标记;维护区间和、最小值、最大值):
struct SegmentTree {
private:
#define ls (u << 1)
#define rs (u << 1 | 1)
#define mid ((l + r) >> 1)
static const int MAXN = 3e5 + 10;
static const ll LINF = 1e18 + 10;
int n;
struct Node {
ll sum, mi, ma;
ll len, tag;
Node() {
sum = 0, mi = 0, ma = 0;
len = 1, tag = 0;
return;
}
Node& operator=(ll val) {
sum = 1LL * len * val, mi = val, ma = val;
tag = 0;
return *this;
}
Node& operator+=(ll val) {
sum += 1LL * len * val, mi += val, ma += val;
tag += val;
return *this;
}
} node[MAXN << 2];
Node merge(const Node &x, const Node &y) {
Node res;
res.sum = x.sum + y.sum;
res.mi = min(x.mi, y.mi);
res.ma = min(x.ma, y.ma);
res.len = x.len + y.len;
res.tag = 0;
return move(res);
}
void pull(int u) {
node[u] = merge(node[ls], node[rs]);
return;
}
void push(int u) {
if (node[u].tag != 0) {
node[ls] += node[u].tag;
node[rs] += node[u].tag;
node[u].tag = 0;
}
return;
}
void iBuild(int u, int l, int r) {
if (l > r)
return;
if (l == r) {
node[u] = Node();
return;
}
iBuild(ls, l, mid);
iBuild(rs, mid + 1, r);
pull(u);
return;
}
void iUpdate(int u, int l, int r, int lpos, int rpos, ll val) {
if (l > r || lpos > r || rpos < l)
return;
if (lpos <= l && r <= rpos) {
node[u] += val;
return;
}
push(u);
iUpdate(ls, l, mid, lpos, rpos, val);
iUpdate(rs, mid + 1, r, lpos, rpos, val);
pull(u);
return;
}
Node iQuery(int u, int l, int r, int lpos, int rpos) {
if (l > r || lpos > r || rpos < l)
return Node();
if (lpos <= l && r <= rpos)
return node[u];
push(u);
Node resL = iQuery(ls, l, mid, lpos, rpos);
Node resR = iQuery(rs, mid + 1, r, lpos, rpos);
return merge(resL, resR);
}
#undef ls
#undef rs
#undef mid
public:
void build(int n) {
this->n = n;
iBuild(1, 1, n);
return;
}
void update(int lpos, int rpos, ll val) {
iUpdate(1, 1, n, lpos, rpos, val);
return;
}
Node query(int lpos, int rpos) {
return iQuery(1, 1, n, lpos, rpos);
}
ll querySum(int lpos, int rpos) {
return query(lpos, rpos).sum;
}
ll queryMin(int lpos, int rpos) {
return query(lpos, rpos).mi;
}
ll queryMax(int lpos, int rpos) {
return query(lpos, rpos).ma;
}
} st;