LOJ #6503. 「雅礼集训 2018 Day4」Magic

多项式复习好题

题意

m m m种小球,每种小球有 a [ i ] a[i] a[i]个,一共有 ∑ a [ i ] = n \sum a[i] = n a[i]=n个小球
n n n个小球排成一排
问恰好有 k k k对相邻的小球种类相同的方案数

题解

对于64’

直接做恰好k对的不好做,我们可以先做出至少k对的来
设 f [ i ] 表 示 至 少 i 对 的 方 案 数 , g [ i ] 为 恰 好 i 对 的 方 案 数 设f[i]表示至少i对的方案数,g[i]为恰好i对的方案数 f[i]i,g[i]i
可 以 发 现 g [ i ] = f [ i ] − ∑ j = i + 1 n g [ j ] ∗ C j i 可以发现g[i]=f[i] - \sum\limits_{j=i+1}^n g[j]*C_j^i g[i]=f[i]j=i+1ng[j]Cji
f [ i ] = ∑ j = i n C j i ∗ g [ j ] f[i] = \sum\limits_{j=i}^n C_j^i*g[j] f[i]=j=inCjig[j]
根 据 二 项 式 反 演 g [ i ] = ∑ j = i n ( − 1 ) j − i ∗ C j i f [ j ] 根据二项式反演g[i]=\sum\limits_{j=i}^n (-1)^{j-i}*C_j^if[j] g[i]=j=in(1)jiCjif[j]

如何求出至少k对???
相邻的一对种类相同的球可以把这一对合成一个球
考虑DP
设 d p [ i ] [ j ] 表 示 前 i 种 颜 色 , 分 成 j 块 ( 相 邻 两 块 不 一 定 种 类 不 同 ) 设dp[i][j]表示前i种颜色,分成j块(相邻两块不一定种类不同) dp[i][j]i,j()
可 得 d p [ i ] [ j ] = ∑ d p [ i − 1 ] [ j − k ] ∗ C a [ i ] − 1 k − 1 ∗ C j k 可得dp[i][j] = \sum dp[i-1][j-k]*C_{a[i]-1}^{k-1}*C_j^k dp[i][j]=dp[i1][jk]Ca[i]1k1Cjk

意思就是把当前种类的球分成 k k k份,就是 ∗ C a [ i ] − 1 k − 1 *C_{a[i]-1}^{k-1} Ca[i]1k1在插入到块中就是 ∗ C j k *C_j^k Cjk

最 后 f [ i ] = d p [ m ] [ n − i ] , 再 通 过 二 项 式 反 演 求 出 g [ i ] 即 可 最后f[i]=dp[m][n-i],再通过二项式反演求出g[i]即可 f[i]=dp[m][ni],g[i]

这样就可以获得64’的好成绩

code:

#include<bits/stdc++.h>
#define mod 998244353
#define ll long long
using namespace std;
int n, a[100005];
ll f[2][100005], fac[100005], _fac[100005], g[100005];
ll qpow(ll x, ll y) {
	ll ret = 1;
	for(; y; y >>= 1, x = x * x % mod) if(y & 1) ret = ret * x % mod;
	return ret;
}
ll C(int x, int y) {
	if(x < y) return 0;
	return fac[x] * _fac[x - y] % mod * _fac[y] % mod;
}//组合数计算 
int main() {
	fac[0] = 1;
	for(int i = 1; i <= 100000; i ++) fac[i] = fac[i - 1] * i % mod;
	_fac[100000] = qpow(fac[100000], mod - 2);
	for(int i = 100000 - 1; i >= 0; i --) _fac[i] = _fac[i + 1] * (i + 1) % mod;//预处理fac和invfac用于算组合数 
	int q, Q;
	scanf("%d%d%d", &n, &Q, &q);//这里的n,m和题目给出的是不同的 
	for(int i = 1; i <= n; i ++) scanf("%d", &a[i]);
	int m = a[1];
	for(int i = 1; i <= m; i ++) f[1][i] = C(a[1] - 1, i - 1);
	
	for(int i = 2; i <= n; i ++) {
		memset(f[i&1], 0, sizeof f[i&1]);
		m += a[i];
		for(int j = 1; j <= m; j ++) {
			for(int k = 1; k <= min(j, a[i]); k ++){
				f[i&1][j] += f[(i - 1)&1][j - k] * C(a[i] - 1, k - 1) % mod * C(j, k) % mod, f[i&1][j] %= mod;//DP 
			}
		}
	} 
	for(int i = 0; i <= m; i ++) g[i] = f[n&1][m - i];
		ll ans = 0;
		for(int i = q; i <= m; i ++) {//二项式反演 
		if((i - q) & 1) ans -= C(i, q) * g[i] % mod, ans = (ans + mod) % mod; 
		else ans += C(i, q) * g[i] % mod, ans %= mod;
	}
	printf("%lld\n", ans);//输出 

	return 0;
}

对于100’

我们再看一下转移的这条式子

d p [ i ] [ j ] = ∑ d p [ i − 1 ] [ j − k ] ∗ C a [ i ] − 1 k − 1 ∗ C j k dp[i][j] = \sum dp[i-1][j-k]*C_{a[i]-1}^{k-1}*C_j^k dp[i][j]=dp[i1][jk]Ca[i]1k1Cjk

d p [ i ] [ j ] 为 f [ j ] , d p [ i − 1 ] [ j − k ] 为 g [ j − k ] dp[i][j]为f[j],dp[i-1][j-k]为g[j-k] dp[i][j]f[j],dp[i1][jk]g[jk]
可得

f [ j ] = ∑ g [ j − k ] ∗ C a [ i ] − 1 k − 1 ∗ C j k f[j] = \sum g[j-k]*C_{a[i]-1}^{k-1}*C_j^k f[j]=g[jk]Ca[i]1k1Cjk

按照一般的讨论都是要把组合数拆开
肯定先拆 C j k C_j^k Cjk因为它和 j , k j,k j,k都有关

f [ j ] = ∑ g [ j − k ] ∗ C a [ i ] − 1 k − 1 ∗ j ! ( j − k ) ! k ! f[j] = \sum g[j-k]*C_{a[i]-1}^{k-1}*\frac{j!}{(j-k)!k!} f[j]=g[jk]Ca[i]1k1(jk)!k!j!

j ! j! j!除过去

1 j ! f [ j ] = ∑ 1 ( j − k ) ! g [ j − k ] ∗ C a [ i ] − 1 k − 1 ∗ 1 k ! \frac{1}{j!}f[j] = \sum \frac{1}{(j-k)!}g[j-k]*C_{a[i]-1}^{k-1}*\frac{1}{k!} j!1f[j]=(jk)!1g[jk]Ca[i]1k1k!1

设 f f [ j ] = 1 j ! f [ j ] , g g [ i ] [ k ] = C a [ i ] − 1 k − 1 ∗ 1 k ! 设ff[j]=\frac{1}{j!}f[j],gg[i][k]=C_{a[i]-1}^{k-1}*\frac{1}{k!} ff[j]=j!1f[j],gg[i][k]=Ca[i]1k1k!1

可 得 f f [ j ] = ∑ f f [ j − k ] ∗ g g [ i ] [ k ] 可得ff[j] = \sum ff[j-k]*gg[i][k] ff[j]=ff[jk]gg[i][k]

这明显就是一个卷积的形式(分治FFT)

没有一眼看出来的可以考虑生成函数

记 F i = ∑ j = 1 ∞ C a [ i ] − 1 j − 1 x j j ! 记F_i=\sum\limits_{j=1}^{\infty}C_{a[i]-1}^{j-1}\frac{x^j}{j!} Fi=j=1Ca[i]1j1j!xj

明 显 答 案 就 是 a n s = ∏ F i 明显答案就是ans=\prod F_i ans=Fi

这个东西也是直接分治+NTT就可以了
大概是O(nlog^2n)
code:

#include<bits/stdc++.h>
#define N 400005
#define mod 998244353
#define G 3 
using namespace std;
inline int add(int x, int y) {
	x += y; return (x >= mod)? x - mod : x;
}
inline int sub(int x, int y) {
	x -= y; return (x < 0)? x + mod : x;
}
inline int mul(int x, int y) {
	return 1ll * x * y % mod;
}
inline int qpow(int x, int y) {
	if(y < 0) {y = -y; x = qpow(x, mod - 2);}
	int ans = 1;
	for(; y; y >>= 1, x = mul(x, x)) if(y & 1) ans = mul(ans, x);
	return ans;
}
int fac[N], invfac[N], rev[N];
void init() {
	fac[0] = 1;
	for(int i = 1; i < N; i ++) fac[i] = mul(fac[i - 1], i);
	invfac[N - 1] = qpow(fac[N - 1], mod - 2);
	for(int i = N - 2; i >= 0; i --) invfac[i] = mul(invfac[i + 1], i + 1);
}
inline int C(int x, int y) {
	if(x < y || y < 0) return 0;
	return mul(fac[x], mul(invfac[x - y], invfac[y]));
}
void ntt(int *a, int len, int o) {
	for(int i = 1; i < len; i ++) rev[i] = (rev[i >> 1] >> 1) | (((i & 1) * len) >> 1);
	for(int i = 1; i < len; i ++) if(i < rev[i]) swap(a[i], a[rev[i]]);
	for(int i = 1; i < len; i <<= 1) {
		int wn = qpow(G, o * (mod - 1) / (i << 1));
		for(int j = 0; j < len; j += i << 1) {
			int wnk = 1;
			for(int k = j; k < j + i; k ++, wnk = mul(wnk, wn)) {
				int X = a[k], Y = mul(a[k + i], wnk);
				a[k] = add(X, Y);
				a[k + i] = sub(X, Y);
			}
		}
	}
	if(o == -1) {
		int invlen = qpow(len, mod - 2);
		for(int i = 0; i < len; i ++) a[i] = mul(a[i], invlen);
	}
}
int fa[N], ga[N];
void mult(vector<int> &ff, vector<int> &gg, vector<int> &fg) {
	for(int i = 0; i < ff.size(); i ++) fa[i] = ff[i];
	for(int i = 0; i < gg.size(); i ++) ga[i] = gg[i];
	int len = 1, l = ff.size() + gg.size() - 1;
	for(;len < l; len <<= 1);
	
	for(int i = ff.size(); i < len; i ++) fa[i] = 0;
	for(int i = gg.size(); i < len; i ++) ga[i] = 0;
	
	ntt(fa, len, 1), ntt(ga, len, 1);
	for(int i = 0; i < len; i ++) fa[i] = mul(fa[i], ga[i]);
	ntt(fa, len, -1);
	
	for(int i = 0; i < l; i ++) fg.push_back(fa[i]);
}//多项式乘法板子 
vector<int> ff[N], f[N];
void solve(int l, int r, int rt) {//分治+NTT 
	if(l == r) {
		ff[rt] = f[l];
		return;
	}
	int mid = (l + r) >> 1;
	solve(l, mid, rt << 1);
	solve(mid + 1, r, rt << 1 | 1);
	mult(ff[rt << 1], ff[rt << 1 | 1], ff[rt]);
}
int n, m, Q, q, g[N], a[N];
int main() {
	init();
	scanf("%d%d%d", &n, &Q, &q);
	for(int i = 1; i <= n; i ++) {
		scanf("%d", &a[i]); m += a[i];
		f[i].push_back(0);
		for(int j = 1; j <= a[i]; j ++) f[i].push_back(mul(C(a[i] - 1, j - 1), invfac[j]));
	}
	solve(1, n, 1);
	for(int i = 0; i <= m; i ++) g[i] = mul(ff[1][m - i], fac[m - i]);
	int ans = 0;
	for(int i = q; i <= m; i ++) {//二项式反演 
		if((i - q) & 1) ans = sub(ans, mul(C(i, q), g[i]));
		else ans = add(ans, mul(C(i, q), g[i]));
	}
	printf("%d\n", ans);
	return 0;
}

考虑这题的加强版,多次询问k

因为每次求k的时候都要二项式反演一遍,所以时间复杂度会变成O(nq)

让我们再看一下二项式反演的那条式子

g [ i ] = ∑ j = i n ( − 1 ) j − i ∗ C j i f [ j ] g[i]=\sum\limits_{j=i}^n (-1)^{j-i}*C_j^if[j] g[i]=j=in(1)jiCjif[j]

g , f g,f g,f反过来一下,方便看

f [ i ] = ∑ j = i n ( − 1 ) j − i ∗ C j i g [ j ] f[i]=\sum\limits_{j=i}^n (-1)^{j-i}*C_j^ig[j] f[i]=j=in(1)jiCjig[j]

按照套路,先把组合数拆开

f [ i ] = ∑ j = i n ( − 1 ) j − i ∗ j ! ( j − i ) ! i ! g [ j ] f[i]=\sum\limits_{j=i}^n (-1)^{j-i}*\frac{j!}{(j-i)!i!}g[j] f[i]=j=in(1)ji(ji)!i!j!g[j]

把i!乘过去

i ! f [ i ] = ∑ j = i n ( − 1 ) j − i ( j − i ) ! ∗ j ! g [ j ] i!f[i]=\sum\limits_{j=i}^n \frac{(-1)^{j-i}}{(j-i)!}*j!g[j] i!f[i]=j=in(ji)!(1)jij!g[j]

再按照套路,将g反转
即设 g ′ [ j ] = g [ n − j ] g'[j]=g[n-j] g[j]=g[nj]

i ! f [ i ] = ∑ n − j = i n ( − 1 ) n − j − i ( n − j − i ) ! ∗ ( n − j ) ! g ′ [ j ] i!f[i]=\sum\limits_{n-j=i}^n \frac{(-1)^{n-j-i}}{(n-j-i)!}*(n-j)!g'[j] i!f[i]=nj=in(nji)!(1)nji(nj)!g[j]

i ! f [ i ] = ∑ j = n − i n ( − 1 ) n − j − i ( n − j − i ) ! ∗ ( n − j ) ! g ′ [ j ] i!f[i]=\sum\limits_{j=n-i}^n \frac{(-1)^{n-j-i}}{(n-j-i)!}*(n-j)!g'[j] i!f[i]=j=nin(nji)!(1)nji(nj)!g[j]

再把中间那坨反转一下(即n-i=i)

( n − i ) ! f [ n − i ] = ∑ j = 0 i ( − 1 ) i − j ( i − j ) ! ∗ ( n − j ) ! g ′ [ j ] (n-i)!f[n-i]=\sum\limits_{j=0}^i \frac{(-1)^{i-j}}{(i-j)!}*(n-j)!g'[j] (ni)!f[ni]=j=0i(ij)!(1)ij(nj)!g[j]

再按照套路,将f反转
即设 f ′ [ i ] = f [ n − i ] f'[i]=f[n-i] f[i]=f[ni]

( n − i ) ! f ′ [ i ] = ∑ j = 0 i ( − 1 ) i − j ( i − j ) ! ∗ ( n − j ) ! g ′ [ j ] (n-i)!f'[i]=\sum\limits_{j=0}^i \frac{(-1)^{i-j}}{(i-j)!}*(n-j)!g'[j] (ni)!f[i]=j=0i(ij)!(1)ij(nj)!g[j]

h [ i ] = ( n − i ) ! f ′ [ i ] , f f [ i ] = ( − 1 ) i i ! , g g [ i ] = ( n − i ) ! g ′ [ i ] h[i]=(n-i)!f'[i],ff[i]=\frac{(-1)^i}{i!},gg[i]=(n-i)!g'[i] h[i]=(ni)!f[i],ff[i]=i!(1)i,gg[i]=(ni)!g[i]

然 后 发 现 h = f f ∗ g g 然后发现 h=ff*gg h=ffgg

xjb卷一下就行了QWQ
code:

#include<bits/stdc++.h>
#define N 4000005
#define mod 998244353
#define G 3 
using namespace std;
inline int add(int x, int y) {
	x += y; return (x >= mod)? x - mod : x;
}
inline int sub(int x, int y) {
	x -= y; return (x < 0)? x + mod : x;
}
inline int mul(int x, int y) {
	return 1ll * x * y % mod;
}
inline int qpow(int x, int y) {
	if(y < 0) {y = -y; x = qpow(x, mod - 2);}
	int ans = 1;
	for(; y; y >>= 1, x = mul(x, x)) if(y & 1) ans = mul(ans, x);
	return ans;
}
int fac[N], invfac[N], rev[N];
void init() {
	fac[0] = 1;
	for(int i = 1; i < N; i ++) fac[i] = mul(fac[i - 1], i);
	invfac[N - 1] = qpow(fac[N - 1], mod - 2);
	for(int i = N - 2; i >= 0; i --) invfac[i] = mul(invfac[i + 1], i + 1);
}
inline int C(int x, int y) {
	if(x < y || y < 0) return 0;
	return mul(fac[x], mul(invfac[x - y], invfac[y]));
}
void ntt(int *a, int len, int o) {
	for(int i = 1; i < len; i ++) rev[i] = (rev[i >> 1] >> 1) | (((i & 1) * len) >> 1);
	for(int i = 1; i < len; i ++) if(i < rev[i]) swap(a[i], a[rev[i]]);
	for(int i = 1; i < len; i <<= 1) {
		int wn = qpow(G, o * (mod - 1) / (i << 1));
		for(int j = 0; j < len; j += i << 1) {
			int wnk = 1;
			for(int k = j; k < j + i; k ++, wnk = mul(wnk, wn)) {
				int X = a[k], Y = mul(a[k + i], wnk);
				a[k] = add(X, Y);
				a[k + i] = sub(X, Y);
			}
		}
	}
	if(o == -1) {
		int invlen = qpow(len, mod - 2);
		for(int i = 0; i < len; i ++) a[i] = mul(a[i], invlen);
	}
}
int fa[N], ga[N];
void mult(vector<int> &ff, vector<int> &gg, vector<int> &fg) {
	for(int i = 0; i < ff.size(); i ++) fa[i] = ff[i];
	for(int i = 0; i < gg.size(); i ++) ga[i] = gg[i];
	int len = 1, l = ff.size() + gg.size() - 1;
	for(;len < l; len <<= 1);
	
	for(int i = ff.size(); i < len; i ++) fa[i] = 0;
	for(int i = gg.size(); i < len; i ++) ga[i] = 0;
	
	ntt(fa, len, 1), ntt(ga, len, 1);
	for(int i = 0; i < len; i ++) fa[i] = mul(fa[i], ga[i]);
	ntt(fa, len, -1);
	
	for(int i = 0; i < l; i ++) fg.push_back(fa[i]);
}
vector<int> ff[N], f[N];
void solve(int l, int r, int rt) {
	if(l == r) {
		ff[rt] = f[l];
		return;
	}
	int mid = (l + r) >> 1;
	solve(l, mid, rt << 1);
	solve(mid + 1, r, rt << 1 | 1);
	mult(ff[rt << 1], ff[rt << 1 | 1], ff[rt]);
}
int n, m, Q, q, g[N], a[N];
int main() {
	init();
	scanf("%d", &n);
	for(int i = 1; i <= n; i ++) {
		scanf("%d", &a[i]); m += a[i];
		f[i].push_back(0);
		for(int j = 1; j <= a[i]; j ++) f[i].push_back(mul(C(a[i] - 1, j - 1), invfac[j]));
	}
	solve(1, n, 1);
	for(int i = 0; i <= m; i ++) g[i] = mul(ff[1][m - i], fac[m - i]);
	//--------以下部分求h-------- 
	memset(fa, 0, sizeof fa);
	memset(ga, 0, sizeof ga);
	for(int i = 0; i <= m; i ++) {
		if(i & 1) fa[i] = sub(0, invfac[i]);
		else fa[i] = invfac[i];
	} 
	for(int i = 0; i <= m; i ++) {
		ga[i] = mul(fac[m - i], g[m - i]);
	}
	
	int len = 1;
	for(;len < m + m + 1;) len <<= 1; 
	ntt(fa, len, 1), ntt(ga, len, 1);
	for(int i = 0; i < len; i ++) fa[i] = mul(fa[i], ga[i]);
	ntt(fa, len, -1);//xjb卷起来 
	
	scanf("%d", &Q);
	while(Q --) {	
		scanf("%d", &q);
		printf("%d\n", mul(invfac[q], fa[m - q]));//记得乘1/q! 
	}
	
	return 0;
}

emmmmmm…多项式的香气
感觉越来越喜欢多项式了
这真是道有意思的题QWQ

posted @ 2019-12-11 20:47  lahlah  阅读(17)  评论(0编辑  收藏  举报