9.20 Magical multisets

题意

\(n\)个可重集,编号为\(1\to n\),开始时都是空的,现在有两种操作

  • 将元素\(x\)加入编号为\([l,r]\)的集合中,若集合中原本就有元素\(x\),那么该集合中所有元素的个数都会翻倍
  • 询问编号为\([l,r]\)集合中元素个数的和,取模\(998244353\)

解法

\(set\)维护区间

对每个颜色开一个\(set\),存储区间(左端点,右端点)

每次加入一个区间,对于区间的交,在线段树对应的区间进行区间乘\(2\)操作;对于原本是空的区间,在对应的区间进行区间加\(1\)的操作

为了保证复杂度,每次要把区间进行合并(可以注意到我们这里只注意区间内有无\(x\)这个数,所以直接合并区间长度即可)

很巧妙啊,第一次接触这种题目

代码

#include <set>
#include <cstdio>

using namespace std;

const int N = 4e5 + 10;
const int mod = 998244353;

int read();

struct seg {
	int l, r;
	seg(int _l, int _r) : l(_l), r(_r) {}
	bool operator < (const seg& _t) const { return r < _t.r; }
};

int n, q;

set<seg> st[N];
typedef set<seg>::iterator iter;

struct SegTree {
#define ls(x) x << 1	
#define rs(x) x << 1 | 1
	
	struct node {
		int val, add, mul;
		node() : val(0), add(0), mul(1) {}
	} t[N << 2];
	
	void addtag(int x, int l, int r, int v) {
		t[x].val = (t[x].val + 1LL * (r - l + 1) * v % mod) % mod;
		t[x].add = (t[x].add + v) % mod;
	}
	
	void multag(int x, int l, int r, int v) {
		t[x].val = 1LL * t[x].val * v % mod;
		t[x].add = 1LL * t[x].add * v % mod;		
		t[x].mul = 1LL * t[x].mul * v % mod;
	}
	
	void pushdown(int x, int l, int r) {
		int mid = l + r >> 1;
		if (t[x].mul != 1) {
			multag(ls(x), l, mid, t[x].mul);
			multag(rs(x), mid + 1, r, t[x].mul);
			t[x].mul = 1;
		}
		if (t[x].add) {
			addtag(ls(x), l, mid, t[x].add);
			addtag(rs(x), mid + 1, r, t[x].add);
			t[x].add = 0;	
		}
	}
	
	void modify(int x, int l, int r, int ql, int qr, int fl, int v) {
		if (ql <= l && r <= qr) 
			return (!fl) ? addtag(x, l, r, v) : multag(x, l, r, v), void();
		int mid = l + r >> 1;
		pushdown(x, l, r);
		if (ql <= mid)
			modify(ls(x), l, mid, ql, qr, fl, v);
		if (qr > mid)
			modify(rs(x), mid + 1, r, ql, qr, fl, v);
		t[x].val = (t[ls(x)].val + t[rs(x)].val) % mod; 
	}
	
	int query(int x, int l, int r, int ql, int qr) {
		if (ql <= l && r <= qr)
			return t[x].val;
		int mid = l + r >> 1, res = 0;
		pushdown(x, l, r);
		if (ql <= mid)	
			res = (res + query(ls(x), l, mid, ql, qr)) % mod;
		if (qr > mid)
			res = (res + query(rs(x), mid + 1, r, ql, qr)) % mod;
		return res;
	}

#undef ls
#undef rs
} tr;

void update(int L, int R, int x) {
	iter it = st[x].lower_bound(seg(L, L));
	
	if (it == st[x].end() || (it -> l) > R) {
		tr.modify(1, 1, n, L, R, 0, 1);
		st[x].insert(seg(L, R));
		return;
	}
	
	if ((it -> l) <= L && (it -> r) >= R) {
		tr.modify(1, 1, n, L, R, 1, 2);
		return;
	}
	
	iter bg = it;
	int sl = L;
	while (it != st[x].end() && (it -> l) <= R) {
		if ((it -> l) > sl)
			tr.modify(1, 1, n, sl, (it -> l) - 1, 0, 1);
		tr.modify(1, 1, n, max(it -> l, L), min(it -> r, R), 1, 2);
		sl = (it -> r) + 1;
		++it;
	}
	--it;
	
	if (sl <= R)
		tr.modify(1, 1, n, sl, R, 0, 1);	
	
	seg New = seg(min(L, bg -> l), max(R, it -> r));
	while (bg != it)
		st[x].erase(bg++);
	st[x].erase(it);
	
	st[x].insert(New);
}

int main() {
	
//	freopen("multiset.in", "r", stdin);
//	freopen("multiset.out", "w", stdout);
	
	n = read(), q = read();

	int op, l, r;	
	while (q--) {
		op = read(), l = read(), r = read();
		if (op == 1)
			update(l, r, read());
		else
			printf("%d\n", tr.query(1, 1, n, l, r));
	}
	
	return 0;
}

#define gc getchar
int read() {
	int x = 0, c = gc();	
	while (c < '0' || c > '9')    c = gc();
	while (c >= '0' && c <= '9')  x = x * 10 + c - 48, c = gc();
	return x;
}
posted @ 2019-09-21 20:00  四季夏目天下第一  阅读(148)  评论(0编辑  收藏  举报