9.20 Magical multisets
题意
有\(n\)个可重集,编号为\(1\to n\),开始时都是空的,现在有两种操作
- 将元素\(x\)加入编号为\([l,r]\)的集合中,若集合中原本就有元素\(x\),那么该集合中所有元素的个数都会翻倍
- 询问编号为\([l,r]\)集合中元素个数的和,取模\(998244353\)
解法
\(set\)维护区间
对每个颜色开一个\(set\),存储区间(左端点,右端点)
每次加入一个区间,对于区间的交,在线段树对应的区间进行区间乘\(2\)操作;对于原本是空的区间,在对应的区间进行区间加\(1\)的操作
为了保证复杂度,每次要把区间进行合并(可以注意到我们这里只注意区间内有无\(x\)这个数,所以直接合并区间长度即可)
很巧妙啊,第一次接触这种题目
代码
#include <set>
#include <cstdio>
using namespace std;
const int N = 4e5 + 10;
const int mod = 998244353;
int read();
struct seg {
int l, r;
seg(int _l, int _r) : l(_l), r(_r) {}
bool operator < (const seg& _t) const { return r < _t.r; }
};
int n, q;
set<seg> st[N];
typedef set<seg>::iterator iter;
struct SegTree {
#define ls(x) x << 1
#define rs(x) x << 1 | 1
struct node {
int val, add, mul;
node() : val(0), add(0), mul(1) {}
} t[N << 2];
void addtag(int x, int l, int r, int v) {
t[x].val = (t[x].val + 1LL * (r - l + 1) * v % mod) % mod;
t[x].add = (t[x].add + v) % mod;
}
void multag(int x, int l, int r, int v) {
t[x].val = 1LL * t[x].val * v % mod;
t[x].add = 1LL * t[x].add * v % mod;
t[x].mul = 1LL * t[x].mul * v % mod;
}
void pushdown(int x, int l, int r) {
int mid = l + r >> 1;
if (t[x].mul != 1) {
multag(ls(x), l, mid, t[x].mul);
multag(rs(x), mid + 1, r, t[x].mul);
t[x].mul = 1;
}
if (t[x].add) {
addtag(ls(x), l, mid, t[x].add);
addtag(rs(x), mid + 1, r, t[x].add);
t[x].add = 0;
}
}
void modify(int x, int l, int r, int ql, int qr, int fl, int v) {
if (ql <= l && r <= qr)
return (!fl) ? addtag(x, l, r, v) : multag(x, l, r, v), void();
int mid = l + r >> 1;
pushdown(x, l, r);
if (ql <= mid)
modify(ls(x), l, mid, ql, qr, fl, v);
if (qr > mid)
modify(rs(x), mid + 1, r, ql, qr, fl, v);
t[x].val = (t[ls(x)].val + t[rs(x)].val) % mod;
}
int query(int x, int l, int r, int ql, int qr) {
if (ql <= l && r <= qr)
return t[x].val;
int mid = l + r >> 1, res = 0;
pushdown(x, l, r);
if (ql <= mid)
res = (res + query(ls(x), l, mid, ql, qr)) % mod;
if (qr > mid)
res = (res + query(rs(x), mid + 1, r, ql, qr)) % mod;
return res;
}
#undef ls
#undef rs
} tr;
void update(int L, int R, int x) {
iter it = st[x].lower_bound(seg(L, L));
if (it == st[x].end() || (it -> l) > R) {
tr.modify(1, 1, n, L, R, 0, 1);
st[x].insert(seg(L, R));
return;
}
if ((it -> l) <= L && (it -> r) >= R) {
tr.modify(1, 1, n, L, R, 1, 2);
return;
}
iter bg = it;
int sl = L;
while (it != st[x].end() && (it -> l) <= R) {
if ((it -> l) > sl)
tr.modify(1, 1, n, sl, (it -> l) - 1, 0, 1);
tr.modify(1, 1, n, max(it -> l, L), min(it -> r, R), 1, 2);
sl = (it -> r) + 1;
++it;
}
--it;
if (sl <= R)
tr.modify(1, 1, n, sl, R, 0, 1);
seg New = seg(min(L, bg -> l), max(R, it -> r));
while (bg != it)
st[x].erase(bg++);
st[x].erase(it);
st[x].insert(New);
}
int main() {
// freopen("multiset.in", "r", stdin);
// freopen("multiset.out", "w", stdout);
n = read(), q = read();
int op, l, r;
while (q--) {
op = read(), l = read(), r = read();
if (op == 1)
update(l, r, read());
else
printf("%d\n", tr.query(1, 1, n, l, r));
}
return 0;
}
#define gc getchar
int read() {
int x = 0, c = gc();
while (c < '0' || c > '9') c = gc();
while (c >= '0' && c <= '9') x = x * 10 + c - 48, c = gc();
return x;
}