Loading

P7224 [RC-04] 子集积 (背包 dp + 复杂度优化)

P7224 [RC-04] 子集积

背包 dp + 复杂度优化

考虑 dp。容易想到背包 dp,设 \(f_{i,j}\) 表示考虑了前 \(i\) 个,当前乘积为 \(j\) 的方案数。枚举 \(a_i\) 的倍数转移。

复杂度 \(O(\sum\limits_{i=1}^n\frac{m}{a_i})\)。如果 \(a_i\) 互不相同,那么近似于 \(O(m\ln m)\)

如果还想要这样的复杂度,可以考虑相同的部分能不能同时处理。假设现在 \(a_i\)\(k\) 个,那么会组成 \(k\) 个不同的 \(a_i\) 的乘积(如 \(a_i\)\(a_i^2\)\(a_i^k\))。将这 \(k\) 个数放入背包的物品中,对于物品 \(a_i^j\),有 \(C(k,j)\) 的系数,每次转移同样是 \(\frac{m}{a_i^j}\) 的复杂度。

那么从原来每个相同的 \(a_i\) 都是 \(O(\frac{m}{a_i})\) 的复杂度,到现在所有相同的 \(a_i\) 总复杂度\(O(\sum\limits_{j=1}^k\frac{m}{a_{i}^j})\),由于下面是指数增长,所以近似于 \(O(\frac{m}{a_i})\)

需要注意的是,对于 \(a_i=1\) 的部分需要单独处理,最后将每个状态 \(f_i\times 2^{cnt_1}\) 即可。

复杂度 \(O(m\ln m)\)

#include <bits/stdc++.h>
#define pii std::pair<int, int>
#define mk std::make_pair
#define fi first
#define se second
#define pb push_back

using i64 = long long;
using ull = unsigned long long;
const i64 iinf = 0x3f3f3f3f, linf = 0x3f3f3f3f3f3f3f3f;
const int N = 1e6 + 10, mod = 998244353;
int n, m, cnt[N];
i64 fac[N], inv[N], a[N], f[N];
i64 qpow(i64 a, i64 b) {
	i64 ret = 1;
	while(b) {
		if(b & 1) ret = ret * a % mod;
		a = a * a % mod;
		b >>= 1;
	}
	return ret;
}
void init() {
	fac[0] = 1;
	for(int i = 1; i <= n; i++) fac[i] = fac[i - 1] * i % mod;

	inv[n] = qpow(fac[n], mod - 2);
	for(int i = n - 1; i >= 0; i--) inv[i] = inv[i + 1] * (i + 1) % mod;
}
i64 C(i64 n, i64 m) {
	if(n < m) return 0;
	return fac[n] * inv[m] % mod * inv[n - m] % mod; 
}
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    
	std::cin >> n >> m;

	init();
	i64 ans = qpow(2, n);

	for(int i = 1; i <= n; i++) {
		std::cin >> a[i];
		cnt[a[i]]++;
	}
	std::sort(a + 1, a + n + 1);
	n = std::unique(a + 1, a + n + 1) - a - 1;

	f[1] = 1;
	for(int i = 1; i <= n; i++) {
		if(a[i] == 1) continue;

		i64 val = 1;
		for(int j = m / a[i]; j >= 1; j--) {
			val = 1;
			for(int k = 1; k <= cnt[a[i]]; k++) {
				val *= a[i];
				if(j * val > m) break;
				f[j * val] = (f[j * val] + f[j] * C(cnt[a[i]], k) % mod) % mod;
			}
		}
	}

	i64 pw = qpow(2, cnt[1]);
	for(int i = 1; i <= m; i++) {
		ans = (ans - f[i] * pw % mod + mod) % mod;
	}

	std::cout << ans << "\n";
	return 0;
}
posted @ 2024-07-06 21:55  Fire_Raku  阅读(1)  评论(0编辑  收藏  举报