[洛谷] P4389 付公主的背包 题解

[洛谷] P4389 付公主的背包 题解

对于 \(n,m\le 3000\) 直接 \(O(nm)\) 背包。

对于 \(n,m\le 10^5\),可以考虑构建 \(OGF\) 进行优化。考虑令 \(ans(x)\) 中 \(x^i\) 的系数为 \(s=i\) 时的方案数,那么有

\[[x^i]\,ans(x)=\sum_{a_1,\cdots,a_n\ge 0} \left[\sum_{j=1}^{n}a_jv_j=i\right] \]

因此可以考虑对于每一个物品 \(i\) 构建 \(OGF\):\(f_i(x)=\sum_{j\ge0}x^{v_i\times j}\),则有:

\[\begin{aligned} ans(x)&=\prod_{i=1}^{n}f_i(x)\\ &=\prod_{i=1}^{n}\frac{1}{1-x^{v_i}}\\ \ln(ans(x))&=-\sum_{i=1}^{n}\ln(1-x^{v_i})\\ &=-\sum_{i=1}^{n}\int\left(\ln(1-x^{v_i})\right)'\\ &=-\sum_{i=1}^{n}\int\frac{-v_ix^{v_i-1}}{1-x^{v_i}}\\ &=\sum_{i=1}^{n}\int\sum_{j\ge0}v_i\times x^{j\times v_i}\times x^{v_i-1}\\ &=\sum_{i=1}^{n}\int\sum_{j\ge0}v_ix^{(j+1)\times v_i-1}\\ &=\sum_{i=1}^{n}\sum_{j\ge0}\frac{v_ix^{(j+1)\times v_i}}{(j+1)\times v_i}\\ &=\sum_{i=1}^{n}\sum_{j\ge1}\frac{(x^{v_i})^{j}}{j}\\ &\equiv\sum_{i=1}^{n}\sum_{j\ge1,v_i\times j \le m} \frac{x^{v_i\times j}}{j}\pmod{x^{m+1}}\\ \end{aligned} \]

发现变成调和级数,这一步的复杂度为 \(O(m\ln m)\),但要注意对于所有 \(v_i\) 相等的物品要一起计算才对,即如果记 \(cnt_i\) 表示体积为 \(i\) 的物品的数量,那么有:

\[\ln(ans(x))\equiv\sum_{i=1}^{m}cnt_i\times\sum_{j\ge1,i\times j \le m}\frac{x^{i\times j}}{j}\pmod{x^{m+1}} \]

最后对右式做一次多项式 \(\exp\) 即可得到 \(ans(x)\)。代码附上:

#include <bits/stdc++.h>
using namespace std;
// Capacity of every coefficient / scratch array in this file.
constexpr int N = 300000 + 30;
// Nominal input bound; not referenced by the code visible in this file.
constexpr int M = 100000;
// Prime modulus used for all arithmetic.
constexpr int mod = 998244353;
// Base from which the forward transform derives its roots of unity.
constexpr int pr = 3;
// Modular inverse of pr (3 * 332748118 = 998244354 = mod + 1), used by the inverse transform.
constexpr int ig = 332748118;

// Modular addition: returns x + y, reduced by one subtraction of `mod` when the
// sum reaches `mod`. Intended for operands that are already in [0, mod).
inline int add(int x, int y) {
	const int sum = x + y;
	if (sum >= mod) return sum - mod;
	return sum;
}

// Modular subtraction: returns x - y, lifted by one addition of `mod` when the
// difference is negative. Intended for operands that are already in [0, mod).
inline int del(int x, int y) {
	const int diff = x - y;
	if (diff < 0) return diff + mod;
	return diff;
}

// Reads the next run of decimal digits from stdin and returns it as a non-negative int.
// Any non-digit characters in front of the number are skipped (so a leading '-' is ignored).
// Returns 0 when end of input is reached before any digit is seen.
inline int read() {
	// getchar() returns int so that EOF can be told apart from every real character;
	// keeping it as int also avoids handing a negative char value to isdigit().
	int ch = getchar();
	while (ch != EOF && !isdigit(ch)) ch = getchar();
	int x = 0;
	// isdigit(EOF) is false, so this loop also stops cleanly at end of input.
	while (isdigit(ch)) {
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	return x;
}

// Emits the decimal digits of x to stdout, most significant first, without a newline.
// Prints nothing at all for x == 0; print() covers that case.
void write(int x) {
	char digits[16];
	int count = 0;
	// Peel digits off the low end, then replay them in reverse.
	for (; x != 0; x /= 10) digits[count++] = static_cast<char>(x % 10 + 48);
	while (count > 0) putchar(digits[--count]);
}

// Prints x followed by a newline; zero is special-cased because write() emits
// no characters for it.
void print(int x) {
	if (x == 0) {
		puts("0");
		return;
	}
	write(x);
	putchar('\n');
}

// Binary exponentiation: returns a^b modulo `mod` (1 when b == 0).
// Products are formed in 64-bit arithmetic before reduction.
int qpow(int a, int b) {
	long long result = 1;
	long long base = a;
	for (; b; b >>= 1) {
		if (b & 1) result = result * base % mod;
		base = base * base % mod;
	}
	return static_cast<int>(result);
}

struct node {
	int n, mx = 1, f[N], g[N];
	void NTT(int *a, int len) {
		int x, y, g, pw;
		for(int j = len >> 1; j >= 1; j >>= 1) {
			g = qpow(pr, (mod - 1) / (j << 1));
			for(int i = 0; i < len; i += (j << 1)) {
				pw = 1;
				for(int k = 0; k < j; ++k, pw = 1ll * pw * g % mod) {
					x = a[i + k]; y = a[i + j + k];
					a[i + k] = add(x, y);
					a[i + j + k] = 1ll * pw * del(x, y) % mod;
				}
			}
		}
	}
	
	void INTT(int *a, int len) {
		int x, y, g, pw;
		for(int j = 1; j < len; j <<= 1) {
			g = qpow(ig, (mod - 1) / (j << 1));
			for(int i = 0; i < len; i += (j << 1)) {
				pw = 1;
				for(int k = 0; k < j; ++k, pw = 1ll * pw * g % mod) {
					x = a[i + k]; y = 1ll * pw * a[i + j + k] % mod;
					a[i + k] = add(x, y);
					a[i + j + k] = del(x, y);
				}
			}
		}
		int Inv = qpow(len, mod - 2);
		for(int i = 0; i < len; ++i) a[i] = 1ll * a[i] * Inv % mod;
	}
	
	void mul(int *a, int *b, int len) {
		NTT(a, len); NTT(b, len);
		for(int i = 0; i < len; ++i) a[i] = 1ll * a[i] * b[i] % mod;
		INTT(a, len);
	}
	
	void INV(int *a, int *b, int len) {
		static int A[N], B[N];
		for(int i = 0; i < len; ++i) A[i] = a[i], a[i] = 0;
		for(int i = 0; i < len; ++i) B[i] = b[i], b[i] = 0;
		B[0] = qpow(A[0], mod - 2);
		for(int j = 2; j < len; j <<= 1) {
			for(int i = 0; i < (j << 1); ++i) a[i] = b[i] = 0;
			for(int i = 0; i < j; ++i) a[i] = A[i];
			for(int i = 0; i < (j >> 1); ++i) b[i] = B[i];
			mul(a, b, j << 1);
			for(int i = j; i < (j << 1); ++i) a[i] = 0;
			for(int i = 0; i < j; ++i) a[i] = (mod - a[i]) % mod;
			a[0] = add(a[0], 2);
			NTT(a, j << 1);
			for(int i = 0; i < (j << 1); ++i) a[i] = 1ll * a[i] * b[i] % mod;
			INTT(a, j << 1);
			for(int i = 0; i < j; ++i) B[i] = a[i];
		}
		for(int i = 0; i < len; ++i) a[i] = A[i], A[i] = 0;
		for(int i = 0; i < len; ++i) b[i] = B[i], B[i] = 0;
	}
	
	void LN(int *a, int *b, int len) {
		static int A[N], B[N];
		for(int i = 0; i < len; ++i) A[i] = a[i], a[i] = 0;
		for(int i = 0; i < len; ++i) B[i] = b[i], b[i] = 0;
		INV(A, B, len);
		for(int i = 0; i < len; ++i) b[i] = B[i];
		for(int i = 0; i < (len >> 1) - 1; ++i) a[i] = 1ll * A[i + 1] * (i + 1) % mod;
		for(int i = (len >> 1); i < len; ++i) a[i] = b[i] = 0;
		mul(a, b, len);
		for(int i = 1; i < (len >> 1); ++i) B[i] = 1ll * a[i - 1] * qpow(i, mod - 2) % mod;
		B[0] = 0;
		for(int i = 0; i < len; ++i) a[i] = A[i], A[i] = 0;
		for(int i = 0; i < len; ++i) b[i] = B[i], B[i] = 0;
	}
	
	void EXP(int *a, int *b, int len) {
		static int A[N], B[N];
		for(int i = 0; i < len; ++i) A[i] = a[i], a[i] = 0;
		for(int i = 0; i < len; ++i) B[i] = b[i], b[i] = 0;
		B[0] = 1;
		for(int j = 2; j < len; j <<= 1) {
			for(int i = 0; i < (j << 1); ++i) a[i] = b[i] = 0;
			for(int i = 0; i < j; ++i) a[i] = B[i];
			LN(a, b, j << 1);
			for(int i = j; i < (j << 1); ++i) b[i] = 0;
			for(int i = 0; i < j; ++i) b[i] = del(A[i], b[i]);
			b[0] = add(b[0], 1);
			mul(a, b, j << 1);
			for(int i = 0; i < j; ++i) B[i] = a[i];
			for(int i = j; i < (j << 1); ++i) B[i] = 0;
		}
		for(int i = 0; i < len; ++i) a[i] = A[i], A[i] = 0;
		for(int i = 0; i < len; ++i) b[i] = B[i], B[i] = 0;
	}

	void SQRT(int *a, int *b, int len) {
		static int A[N], B[N];
		for(int i = 0; i < len; ++i) A[i] = a[i], a[i] = 0;
		for(int i = 0; i < len; ++i) B[i] = b[i], b[i] = 0;
		B[0] = 1;
		for(int j = 2; j < len; j <<= 1) {
			for(int i = 0; i < (j << 1); ++i) a[i] = b[i] = 0;
			for(int i = 0; i < (j >> 1); ++i) a[i] = b[i] = B[i];
			mul(a, b, j);
			for(int i = 0; i < j; ++i) a[i] = (a[i] + A[i]) % mod;
			for(int i = 0; i < (j >> 1); ++i) b[i] = 2ll * B[i] % mod;
			for(int i = (j >> 1); i < j; ++i) b[i] = 0;
			INV(b, B, j << 1);
			for(int i = j; i < (j << 1); ++i) B[i] = 0;
			mul(a, B, j << 1);
			for(int i = 0; i < j; ++i) B[i] = a[i];
			for(int i = j; i < (j << 1); ++i) B[i] = 0;
		}
	}
	
	void init() {while(mx <= n + n) mx <<= 1;}
}F;
// n: number of items, m: largest volume to answer for (both read in main()).
int n, m;
// cnt[v]: how many items have volume v; inv[j]: modular inverse of j.
int cnt[N], inv[N];

int main() {
	n = read(); m = read(); F.n = m; F.init();
	inv[1] = 1;
	for(int i = 2; i <= m; ++i) inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod;
	for(int i = 1; i <= n; ++i) cnt[read()]++;
	for(int i = 1; i <= m; ++i)
		for(int j = 1; i * j <= m; ++j)
			F.f[i * j] = (1ll * F.f[i * j] + 1ll * cnt[i] * inv[j] % mod) % mod;
	F.EXP(F.f, F.g, F.mx);
	for(int i = 1; i <= m; ++i) printf("%d\n", F.g[i]);
	return 0;
}

\(END.\)

posted @ 2023-02-27 21:59  _zyc  阅读(45)  评论(0)  编辑  收藏  举报