【YBT2023寒假Day10 A】集合比较(数学)(启发式分裂)

集合比较

题目链接:YBT2023寒假Day10 A

题目大意

给你一个长度为 n 的排列 p,定义两个大小为 n 不可重集合的比较方式是先比较各自第 p1 小的元素,如果相同比 p2,以此类推。
给你一个大小为 n 不可重集合,问你有多少个大小为 n 的不可重集合比给出的字典序小。

思路

考虑把集合排序,变成序列。

考虑模拟字典序的比较方式。
那每次你就看哪个位置,如果不一样就可以直接出结果,否则就要继续看下一个。

那在你看第 \(p_i\) 的时候,我们已经确定了 \(p_{1}\sim p_{i-1}\) 位置的值。
那就分成了若干个段,每一段有要填的数的个数 \([l,r]\),也有值域的选择 \([A_{l-1}+1,A_{r+1}-1]\),然后根据不可重,不难用组合数算出每一段的方案数(\(\binom{A_{r+1}-A_{l-1}-1}{r-l+1}\)

然后我们考虑加上 \(p_i\) 位置相等或不等的限制。
如果相等,那我们这个位置选的值是固定的,然后就把一段拆成了两段,维护上面所有方案数的乘积即可。
如果不相等,会发现我们至少需要枚举 \(p_i\) 这个位置你选了什么。
但是这样就 \(O(n^2)\) 了。

发现这个过程是分裂区间的过程,而且我们还有一种算法是容斥,用所有的方案减去枚举不合法的 \(p_i\) 位置选的值的贡献。
然后你每次选复杂度小的那边算就可以做到 \(O(n\log n)\) 了。

代码

#include<set>
#include<cstdio>
#include<algorithm>
#define mo 998244353

using namespace std;

const int N = 2e5 + 100;
int n, m, b[N], p[N], jc[N], inv[N], invs[N];

int add(int x, int y) {return x + y >= mo ? x + y - mo : x + y;}
int dec(int x, int y) {return x < y ? x - y + mo : x - y;}
int mul(int x, int y) {return 1ll * x * y % mo;}
int C(int n, int m) {
	if (n < 0 || m < 0 || n < m) return 0;
	return mul(mul(jc[n], invs[m]), invs[n - m]);
}
int ksm(int x, int y) {
	int re = 1;
	while (y) {
		if (y & 1) re = mul(re, x);
		x = mul(x, x); y >>= 1;
	}
	return re;
}

struct node {
	int l, r, x;
};
set <node> s;

bool operator <(node x, node y) {
	if (x.r != y.r) return x.r < y.r;
	return x.l < y.l;
}

int main() {
	freopen("dict.in", "r", stdin);
	freopen("dict.out", "w", stdout);
	
	jc[0] = 1; for (int i = 1; i < N; i++) jc[i] = mul(jc[i - 1], i);
	inv[0] = inv[1] = 1; for (int i = 2; i < N; i++) inv[i] = mul(inv[mo % i], mo - mo / i);
	invs[0] = 1; for (int i = 1; i < N; i++) invs[i] = mul(invs[i - 1], inv[i]);
	
	scanf("%d %d", &n, &m);
	for (int i = 1; i <= n; i++) scanf("%d", &b[i]);
	for (int i = 1; i <= n; i++) scanf("%d", &p[i]);
	sort(b + 1, b + n + 1);
	
	b[0] = 0; b[n + 1] = m + 1;
	int di = C(m, n), ans = 0;
	s.insert((node){1, n, di});
	for (int i = 1; i <= n; i++) {
		set <node> ::iterator it = s.lower_bound((node){0, p[i], 0});
		node big = *it; s.erase(it); di = mul(di, ksm(big.x, mo - 2));
		
		int now = 0;
		if (b[p[i]] - b[big.l - 1] - 1 <= b[big.r + 1] - b[p[i]]) {
			for (int j = b[big.l - 1] + 1; j <= b[p[i]] - 1; j++)
				now = add(now, mul(C(j - b[big.l - 1] - 1, p[i] - big.l), C(b[big.r + 1] - j - 1, big.r - p[i])));
		}
		else {
			now = big.x;
			for (int j = b[p[i]]; j <= b[big.r + 1] - 1; j++)
				now = dec(now, mul(C(j - b[big.l - 1] - 1, p[i] - big.l), C(b[big.r + 1] - j - 1, big.r - p[i])));
		}
		ans = add(ans, mul(now, di));
		
		if (big.l <= p[i] - 1) {
			s.insert((node){big.l, p[i] - 1, C(b[p[i]] - b[big.l - 1] - 1, p[i] - big.l)});
			di = mul(di, C(b[p[i]] - b[big.l - 1] - 1, p[i] - big.l));
		}
		if (p[i] + 1 <= big.r) {
			s.insert((node){p[i] + 1, big.r, C(b[big.r + 1] - b[p[i]] - 1, big.r - p[i])});
			di = mul(di, C(b[big.r + 1] - b[p[i]] - 1, big.r - p[i]));
		}
	}
	printf("%d", ans);
	
	return 0;
}
posted @ 2023-02-20 00:31  あおいSakura  阅读(33)  评论(0编辑  收藏  举报