[CF755G] PolandBall and Many Other Balls

\(\text{Problem}:\)PolandBall and Many Other Balls

\(\text{Preface}:\)

这是一道有着多种不同做法的经典好题。由于笔者水平有限,仅在下文介绍三种主流做法。

\(\text{Solution 1}:\) 组合容斥

\(f_{k}\) 表示取出 \(k\) 组的方案数。考虑其组合意义,从 \(k\) 组中选 \(i\) 组包含两个球,将包含两个球的位置绑定,就只剩 \(n-i\) 个位置放 \(k\) 个球。故有:

\[f_{k}=\sum\limits_{i=0}^{k}\binom{k}{i}\binom{n-i}{k} \]

发现这个式子难以展开推导,故考虑改变其组合意义。发现 \(\binom{k}{i}\binom{n-i}{k}\) 等价于一排有 \(n\) 个球,在前 \(k\) 个中选若干个,在剩余的球中任意选 \(k\) 个的方案数。根据定义,第一次选择的球和第二次选择的 \(k\) 个球是没有交集的。这提示我们容斥求解。

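去掉两次选择无交的限制(第二次改为在全部 \(n\) 个球中选 \(k\) 个),设 \(g_{i}\) 表示恰好有 \(i\) 个球被两次选择同时选到的方案数,\(h_{i}\) 表示钦定有 \(i\) 个球被重复选到的方案数。考虑求出 \(h_{i}\):在前 \(k\) 个球中钦定 \(i\) 个被两次选择同时选中,有 \(\binom{k}{i}\) 种;前 \(k\) 个中剩下的 \(k-i\) 个在第一次选择中可选可不选,有 \(2^{k-i}\) 种;第二次选择还需从其余 \(n-i\) 个球中选出 \(k-i\) 个,有 \(\binom{n-i}{k-i}\) 种。故 \(h_{i}=\binom{k}{i}2^{k-i}\binom{n-i}{k-i}\),由二项式反演得: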

\[\begin{aligned} g_{i}&=\sum\limits_{j=i}^{k}(-1)^{j-i}\binom{j}{i}h_{j}\\ &=\sum\limits_{j=i}^{k}(-1)^{j-i}\binom{j}{i}\binom{k}{j}2^{k-j}\binom{n-j}{k-j} \end{aligned} \]

显然 \(f_{k}=g_{0}\),得到:

\[\begin{aligned} f_{k}&=\sum\limits_{j=0}^{k}(-1)^{j}\binom{k}{j}2^{k-j}\binom{n-j}{k-j}\\ &=\frac{k!}{(n-k)!}\sum\limits_{j=0}^{k}\frac{(-1)^{j}(n-j)!}{j!}\cdot\frac{2^{k-j}}{(k-j)!(k-j)!}\\ &=k!\cdot n^{\underline{k}}\sum\limits_{j=0}^{k}\frac{(-1)^{j}}{j!\cdot n^{\underline{j}}}\cdot\frac{2^{k-j}}{(k-j)!(k-j)!} \end{aligned} \]

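注意到若 \(n^{\underline{j}}\) 的某个因子是 \(998244353\) 的倍数(即 \(n^{\underline{j}}\equiv 0\pmod{998244353}\)),那么 \(n^{\underline{k}}\ (k\geq j)\) 也含有这个因子,上式中的 \(\frac{1}{n^{\underline{j}}}\) 无法直接计算。由于 \(k\) 远小于 \(998244353\),\(n,n-1,\cdots,n-k+1\) 中至多有一个因子是 \(998244353\) 的倍数,设它第一次出现在 \(n^{\underline{c}}\) 中。对 \(k<c\) 直接使用上式;对 \(k\geq c\),\(j<c\) 的项因 \(\frac{n^{\underline{k}}}{n^{\underline{j}}}\) 含有该因子而为 \(0\),\(j\geq c\) 的项则可以在 \(n^{\underline{k}}\)\(n^{\underline{j}}\) 中同时约去该因子。两段各做一次 \(\text{NTT}\) 即可,时间复杂度 \(O(k\log k)\)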

\(\text{Code 1}:\)

#include <bits/stdc++.h>
#pragma GCC optimize(3)
//#define int long long
#define ri register
#define mk make_pair
#define fi first
#define se second
#define pb push_back
#define eb emplace_back
#define is insert
#define es erase
#define vi vector<int>
#define vpi vector<pair<int,int>>
using namespace std;
// N bounds the size of every global table; Mod = 119 * 2^23 + 1 is the
// NTT-friendly prime (3 is used as its primitive root when building r[][]).
const int N=135010;
const int Mod=998244353;
// Reads one decimal integer from stdin. Non-digit characters before the first
// digit are skipped, and a '-' among them makes the result negative. Digits are
// consumed up to and including the first non-digit character.
// Returns 0 if end of input is reached before any digit. The previous version
// looped forever in that case, because EOF was only compared against '0'..'9'.
inline int read()
{
	int s=0, w=1;
	int ch=getchar(); // int rather than char so EOF stays distinguishable
	while(ch!=EOF&&(ch<'0'||ch>'9'))
	{
		if(ch=='-') w=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9')
	{
		s=s*10+(ch-'0');
		ch=getchar();
	}
	return s*w;
}
int n,K;          // n balls; answers are printed for 1..K groups
int rev[N];       // bit-reversal permutation (filled by Get_Rev)
int r[24][2];     // r[i][1], r[i][0]: primitive 2^i-th root of unity and its inverse
int fac[N+5];     // fac[i] = i! mod Mod
int inv[N+5];     // inv[i] = (i!)^{-1} mod Mod (inverse factorial, despite the name)
int pw[N+5];      // pw[i] = 2^i mod Mod
int ifac[N+5];    // ifac[i] = falling factorial n(n-1)...(n-i+1) mod Mod (despite the name)
int tfac[N+5];    // like ifac, but factors divisible by Mod are skipped
// Modular exponentiation by repeated squaring: returns x^p mod Mod.
// p must be non-negative.
inline int ksc(int x,int p)
{
	long long base=x, acc=1;
	while(p)
	{
		if(p&1) acc=acc*base%Mod;
		base=base*base%Mod;
		p>>=1;
	}
	return (int)acc;
}
// Fills rev[0..T) with the bit-reversal permutation for a power-of-two T.
// Each entry is derived from rev[i>>1], so rev[0] must be 0 (it is never
// changed from its zero initial value).
inline void Get_Rev(int T)
{
	const int top=T>>1;
	for(int i=0;i<T;i++)
	{
		rev[i]=rev[i>>1]>>1;
		if(i&1) rev[i]|=top;
	}
}
inline void DFT(int T,vector<int> &s,int type)
{
	for(ri int i=0;i<T;i++) if(rev[i]<i) swap(s[i],s[rev[i]]);
	for(ri int i=2,cnt=1;i<=T;i<<=1,cnt++)
	{
		int wn=r[cnt][type];
		for(ri int j=0,mid=(i>>1);j<T;j+=i)
		{
			for(ri int k=0,w=1;k<mid;k++,w=1ll*w*wn%Mod)
			{
				int x=s[j+k], y=1ll*w*s[j+mid+k]%Mod;
				s[j+k]=(x+y)%Mod;
				s[j+mid+k]=x-y;
				if(s[j+mid+k]<0) s[j+mid+k]+=Mod;
			}
		}
	}
	if(!type) for(ri int i=0,inv=ksc(T,Mod-2);i<T;i++) s[i]=1ll*s[i]*inv%Mod;
}
// Polynomial multiplication by NTT. A (degree <= n) is replaced by the product
// A*B, stored in a vector of length T, the smallest power of two greater than
// n+m. Entries of A beyond index n and of B beyond index m are treated as zero.
// B is taken by value, so the caller's copy is not modified.
inline void NTT(int n,int m,vector<int> &A,vector<int> B)
{
	int T=1;
	while(T<=n+m) T<<=1;
	Get_Rev(T);
	A.resize(T);
	B.resize(T);
	// Clear anything past the declared degrees before transforming.
	for(int i=T-1;i>n;i--) A[i]=0;
	for(int i=T-1;i>m;i--) B[i]=0;
	DFT(T,A,1);
	DFT(T,B,1);
	for(int i=0;i<T;i++) A[i]=1ll*A[i]*B[i]%Mod;
	DFT(T,A,0);
}
// Program entry for Solution 1. Reads n and K and prints f_1..f_K, where
//   f_k = k! * n^{(k)} * sum_{j=0}^{k} (-1)^j / (j! * n^{(j)}) * 2^{k-j} / ((k-j)!)^2
// and n^{(i)} denotes the falling factorial n(n-1)...(n-i+1).
// The sum is a convolution of A[j] = (-1)^j / (j! n^{(j)}) with
// B[j] = 2^j / (j!)^2. From the first index ct at which n^{(ct)} == 0 (mod Mod),
// a second convolution is used in which that zero factor is left out (tfac).
signed main()
{
	// fac[i] = i!, inv[i] = (i!)^{-1}, pw[i] = 2^i, all mod Mod, for 0 <= i <= N.
	fac[0]=1;
	for(int i=1;i<=N;i++) fac[i]=1ll*fac[i-1]*i%Mod;
	inv[N]=ksc(fac[N],Mod-2);
	for(int i=N;i;i--) inv[i-1]=1ll*inv[i]*i%Mod;
	pw[0]=1;
	for(int i=1;i<=N;i++) pw[i]=(pw[i-1]<<1)%Mod;
	// r[i][1] / r[i][0]: primitive 2^i-th root of unity and its inverse
	// (Mod = 119 * 2^23 + 1, primitive root 3).
	r[23][1]=ksc(3,119), r[23][0]=ksc(ksc(3,Mod-2),119);
	for(int i=22;i>=0;i--) r[i][0]=1ll*r[i+1][0]*r[i+1][0]%Mod, r[i][1]=1ll*r[i+1][1]*r[i+1][1]%Mod;
	n=read(), K=read();
	// ifac[i] = n^{(i)} mod Mod. It stays 0 for i > n, and becomes 0 once a
	// factor divisible by Mod appears. tfac[i] is the same product with the
	// factors divisible by Mod left out.
	ifac[0]=tfac[0]=1;
	for(int i=1;i<=n&&i<=N;i++)
	{
		ifac[i]=1ll*ifac[i-1]*((n-i+1)%Mod)%Mod;
		tfac[i]=tfac[i-1];
		if((n-i+1)%Mod) tfac[i]=1ll*tfac[i]*((n-i+1)%Mod)%Mod;
	}
	vector<int> A,B;
	// ct = first index in 1..K with ifac[ct] == 0, or K+1 if there is none.
	int ct=K+1;
	for(int i=1;i<=K;i++) if(!ifac[i]) { ct=i; break; }
	// Part 1: answers 1..ct-1, where every n^{(j)} with j < ct is invertible.
	A.resize(ct), B.resize(ct);
	for(int i=0;i<ct;i++)
	{
		const int term=1ll*inv[i]*ksc(ifac[i],Mod-2)%Mod;
		A[i]=(i&1)?(Mod-term)%Mod:term;
		B[i]=1ll*pw[i]*inv[i]%Mod*inv[i]%Mod;
	}
	NTT(ct,ct,A,B);
	// %d requires an int; the products below are long long, so cast explicitly.
	for(int i=1;i<ct;i++)
		printf("%d ",(int)(1ll*A[i]*ifac[i]%Mod*fac[i]%Mod));
	// Part 2: answers ct..K. Terms with j < ct contain the vanishing factor and
	// drop out. For j >= ct, that factor cancels between n^{(k)} and n^{(j)}.
	// When ct > K there is nothing left to compute, so the convolution is skipped.
	if(ct<=K)
	{
		vector<int>().swap(A);
		A.resize(K+1), B.resize(K+1);
		for(int i=ct;i<=K;i++)
		{
			const int term=1ll*inv[i]*ksc(tfac[i],Mod-2)%Mod;
			A[i]=(i&1)?(Mod-term)%Mod:term;
			B[i]=1ll*pw[i]*inv[i]%Mod*inv[i]%Mod;
		}
		NTT(K,K,A,B);
		for(int i=ct;i<=K;i++)
			printf("%d ",(int)(1ll*A[i]*tfac[i]%Mod*fac[i]%Mod));
	}
	puts("");
	return 0;
}

\(\text{Solution 2}:\) 特征方程

\(f_{i,j}\) 表示前 \(i\) 个球分成 \(j\) 组的方案数,有转移:

\[f_{i,j}=f_{i-1,j}+f_{i-1,j-1}+f_{i-2,j-1} \]

\(F_{i}(x)\) 表示序列 \(f_{i}\)\(\text{OGF}\),有:

\[F_{i}(x)=(x+1)F_{i-1}(x)+xF_{i-2}(x) \]

该递推式的特征方程为:

\[z^{2}-(x+1)z-x=0 \]

设两根为 \(z_{1},z_{2}\),其中 \(z_{1}\) 为根号前取正号的根(形式幂级数之间没有大小关系,这里只是用下标区分两根)。由求根公式,有:

\[z_{1}=\frac{x+1+\sqrt{x^{2}+6x+1}}{2},z_{2}=\frac{x+1-\sqrt{x^{2}+6x+1}}{2} \]

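现在设 \(F_{n}(x)=c_{1}z_{1}^{n}+c_{2}z_{2}^{n}\),由 \(F_{0}(x)=1,F_{1}(x)=1+x\),即 \(c_{1}+c_{2}=1,\ c_{1}z_{1}+c_{2}z_{2}=1+x=z_{1}+z_{2}\),且 \(z_{1}-z_{2}=\sqrt{x^{2}+6x+1}\),得到:

\[\begin{aligned} c_{1}&=\frac{z_{1}}{\sqrt{x^{2}+6x+1}},\quad c_{2}=-\frac{z_{2}}{\sqrt{x^{2}+6x+1}}\\ F_{n}(x)&=\frac{z_{1}^{n+1}-z_{2}^{n+1}}{\sqrt{x^{2}+6x+1}} \end{aligned} \]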

发现 \(z_{2}\) 的常数项为 \(0\),即 \(z_{2}^{n+1}\equiv 0\pmod {x^{n+1}}\)。又因为 \(F_{n}(x)\) 的次数为 \(n\),它由其模 \(x^{n+1}\) 的结果唯一确定,故有:

\[F_{n}(x)\equiv\frac{z_{1}^{n+1}}{\sqrt{x^{2}+6x+1}}=\frac{(x+1+\sqrt{x^{2}+6x+1})^{n+1}}{2^{n+1}\sqrt{x^{2}+6x+1}}\pmod{x^{n+1}} \]

利用多项式开根、求逆以及多项式快速幂(\(\ln\)\(\exp\))求出前 \(k+1\) 项即可,时间复杂度 \(O(k\log k)\)。注意上式只在模 \(x^{n+1}\) 意义下成立,当 \(k>n\) 时,次数大于 \(n\) 的项应直接输出 \(0\)

\(\text{Code 2}:\)

#include <bits/stdc++.h>
#pragma GCC optimize(3)
//#define int long long
#define ri register
#define mk make_pair
#define fi first
#define se second
#define pb push_back
#define eb emplace_back
#define is insert
#define es erase
#define vi vector<int>
#define vpi vector<pair<int,int>>
using namespace std;
// N bounds the size of every global table; Mod = 119 * 2^23 + 1 is the
// NTT-friendly prime (3 is used as its primitive root when building r[][]).
const int N=135010;
const int Mod=998244353;
// Reads one decimal integer from stdin. Non-digit characters before the first
// digit are skipped, and a '-' among them makes the result negative. Digits are
// consumed up to and including the first non-digit character.
// Returns 0 if end of input is reached before any digit. The previous version
// looped forever in that case, because EOF was only compared against '0'..'9'.
inline int read()
{
	int s=0, w=1;
	int ch=getchar(); // int rather than char so EOF stays distinguishable
	while(ch!=EOF&&(ch<'0'||ch>'9'))
	{
		if(ch=='-') w=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9')
	{
		s=s*10+(ch-'0');
		ch=getchar();
	}
	return s*w;
}
int n,K;          // input n and k (main increments K to k+1 after reading)
int rev[N];       // bit-reversal permutation (filled by Get_Rev)
int r[24][2];     // r[i][1], r[i][0]: primitive 2^i-th root of unity and its inverse
int iiv[N+5];     // iiv[i] = i^{-1} mod Mod, used by GetJi
// Modular exponentiation by repeated squaring: returns x^p mod Mod.
// p must be non-negative.
inline int ksc(int x,int p)
{
	long long base=x, acc=1;
	while(p)
	{
		if(p&1) acc=acc*base%Mod;
		base=base*base%Mod;
		p>>=1;
	}
	return (int)acc;
}
// Fills rev[0..T) with the bit-reversal permutation for a power-of-two T.
// Each entry is derived from rev[i>>1], so rev[0] must be 0 (it is never
// changed from its zero initial value).
inline void Get_Rev(int T)
{
	const int top=T>>1;
	for(int i=0;i<T;i++)
	{
		rev[i]=rev[i>>1]>>1;
		if(i&1) rev[i]|=top;
	}
}
inline void DFT(int T,vector<int> &s,int type)
{
	for(ri int i=0;i<T;i++) if(rev[i]<i) swap(s[i],s[rev[i]]);
	for(ri int i=2,cnt=1;i<=T;i<<=1,cnt++)
	{
		int wn=r[cnt][type];
		for(ri int j=0,mid=(i>>1);j<T;j+=i)
		{
			for(ri int k=0,w=1;k<mid;k++,w=1ll*w*wn%Mod)
			{
				int x=s[j+k], y=1ll*w*s[j+mid+k]%Mod;
				s[j+k]=(x+y)%Mod;
				s[j+mid+k]=x-y;
				if(s[j+mid+k]<0) s[j+mid+k]+=Mod;
			}
		}
	}
	if(!type) for(ri int i=0,inv=ksc(T,Mod-2);i<T;i++) s[i]=1ll*s[i]*inv%Mod;
}
// Polynomial multiplication by NTT: A is replaced by the length-T cyclic
// convolution of A and B, where T is the smallest power of two greater than n+m.
// Entries beyond the stated degrees are NOT cleared. Every caller in this
// program passes vectors of length n and m. B is taken by value.
inline void NTT(int n,int m,vector<int> &A,vector<int> B)
{
	int T=1;
	for(;T<=n+m;T<<=1);
	Get_Rev(T);
	A.resize(T);
	B.resize(T);
	DFT(T,A,1);
	DFT(T,B,1);
	for(int i=0;i<T;i++) A[i]=1ll*A[i]*B[i]%Mod;
	DFT(T,A,0);
}
// Power-series inverse by Newton iteration: F = 1/G mod x^n, using
// F <- F0 * (2 - G * F0) where F0 is the inverse to half precision.
// Requires G[0] != 0 and n >= 1. Only F[0..n) is written. The correction step
// reads F[0..n), and every caller in this program passes F zero-filled.
void GetInv(int n,vector<int> &F,vector<int> G)
{
	if(n==1) { F[0]=ksc(G[0],Mod-2); return; }
	GetInv((n+1)/2,F,G);
	int T=1;
	while(T<=2*n) T<<=1;
	Get_Rev(T); // after the recursive call, which reuses rev[] for smaller T
	vector<int> P(T,0), Q(T,0);
	for(int i=0;i<n;i++) { P[i]=F[i]; Q[i]=G[i]; }
	DFT(T,P,1);
	DFT(T,Q,1);
	for(int i=0;i<T;i++)
	{
		const long long twiceP=2ll*P[i]%Mod;
		const long long gpp=1ll*Q[i]*P[i]%Mod*P[i]%Mod;
		P[i]=(int)((twiceP-gpp+Mod)%Mod);
	}
	DFT(T,P,0);
	for(int i=0;i<n;i++) F[i]=P[i];
}
// Formal derivative truncated to n terms: A[i] = (i+1) * B[i+1] for
// 0 <= i < n-1, and A[n-1] = 0. A and B must have at least n entries.
void GetDao(int n,vector<int> &A,vector<int> B)
{
	A[n-1]=0;
	for(int k=1;k<n;k++) A[k-1]=1ll*k*B[k]%Mod;
}
// Formal integral with zero constant term, truncated to n terms:
// A[0] = 0 and A[i] = B[i-1] / i for 1 <= i < n (using the inverse table iiv).
void GetJi(int n,vector<int> &A,vector<int> B)
{
	A[0]=0;
	for(int i=n-1;i>=1;i--) A[i]=1ll*B[i-1]*iiv[i]%Mod;
}
// Power-series logarithm: F = integral(G' / G) mod x^n. The constant term of
// the result is always 0 (GetJi sets it), so this equals ln G only when G[0] == 1.
void GetLn(int n,vector<int> &F,vector<int> G)
{
	vector<int> dG(n), invG(n);
	GetDao(n,dG,G);      // dG = G'
	GetInv(n,invG,G);    // invG = 1/G mod x^n
	NTT(n,n,dG,invG);    // dG = G' / G (first n entries are used)
	GetJi(n,F,dG);
}
// Power-series exponential by Newton iteration: F = exp(G) mod x^n, using
// F <- F0 * (1 - ln F0 + G) where F0 is exp(G) to half precision.
// G is taken by value, so GetExp(n, F, F) is allowed. n must be >= 1.
void GetExp(int n,vector<int> &F,vector<int> G)
{
	if(n==1) { F[0]=1; return; }
	GetExp((n+1)/2,F,G);
	vector<int> lnF(n);
	GetLn(n,lnF,F);
	int T=1;
	while(T<=2*n) T<<=1;
	Get_Rev(T);
	vector<int> P(T,0), Q(T,0);
	for(int i=0;i<n;i++)
	{
		P[i]=F[i];
		Q[i]=(G[i]-lnF[i]+Mod)%Mod;
	}
	Q[0]+=1; // the "1 +" of the Newton step
	DFT(T,P,1);
	DFT(T,Q,1);
	for(int i=0;i<T;i++) P[i]=1ll*P[i]*Q[i]%Mod;
	DFT(T,P,0);
	for(int i=0;i<n;i++) F[i]=P[i];
}
// Element x + y * sqrt(I2) of the quadratic extension used by Cipolla's
// algorithm. I2 is set by Cipolla before any multiplication happens.
struct Node { int x,y; }; int I2;
inline Node operator * (Node a,Node b)
{
	// (a.x + a.y*s)(b.x + b.y*s) with s^2 = I2.
	const int re=(1ll*a.x*b.x%Mod+1ll*a.y*b.y%Mod*I2%Mod)%Mod;
	const int im=(1ll*a.x*b.y%Mod+1ll*a.y*b.x%Mod)%Mod;
	Node out;
	out.x=re;
	out.y=im;
	return out;
}
// Exponentiation in the Cipolla extension field: returns x^p (p >= 0).
inline Node KSC(Node x,int p)
{
	Node acc;
	acc.x=1;
	acc.y=0;
	while(p)
	{
		if(p&1) acc=acc*x;
		x=x*x;
		p>>=1;
	}
	return acc;
}
// Euler's criterion: true iff x^((Mod-1)/2) == 1 (mod Mod), i.e. x is a
// non-zero quadratic residue. Returns false for x == 0.
inline bool Check(int x)
{
	const int legendre=ksc(x,(Mod-1)/2);
	return legendre==1;
}
// Returns a pseudo-random value in [0, Mod), mixed from four rand() calls.
// The accumulation is done in long long: w (< Mod) plus a difference of two
// rand() results can exceed INT_MAX when RAND_MAX is 2^31-1. In the original
// int arithmetic that was signed overflow, which is undefined behaviour.
inline int Rand()
{
	long long w=1ll*rand()*rand()%Mod;
	w+=(long long)rand()-rand();
	w=(w%Mod+Mod)%Mod;
	return (int)w;
}
inline int Cipolla(int n)
{
	if(!n) return 0;
	int a=0;
	while(Check((1ll*a*a%Mod-n+Mod)%Mod)) a=Rand();
	I2=(1ll*a*a%Mod-n+Mod)%Mod;
	int X1=KSC((Node){a,1},(Mod+1)/2).x;
	return min(X1,Mod-X1);
}
// Power-series square root by Newton iteration: F = sqrt(G) mod x^n, using
// F <- (F0 + G / F0) / 2. The constant term comes from Cipolla(G[0]).
// G is taken by value, so Getsqrt(n, a, a) is allowed. n must be >= 1.
void Getsqrt(int n,vector<int> &F,vector<int> G)
{
	if(n==1) { F[0]=Cipolla(G[0]); return; }
	Getsqrt((n+1)/2,F,G);
	vector<int> quo(n), rhs(n);
	GetInv(n,quo,F);               // quo = 1 / F0 mod x^n
	for(int i=0;i<n;i++) rhs[i]=G[i];
	NTT(n,n,quo,rhs);              // quo = G / F0 (first n entries are used)
	const int half=(Mod+1)/2;      // 1/2 mod Mod
	for(int i=0;i<n;i++) F[i]=1ll*(F[i]+quo[i])*half%Mod;
}
// Program entry for Solution 2. Reads n and k, then prints the counts for
// 1..k groups using F_n(x) = z1^{n+1} / sqrt(x^2+6x+1) (mod x^{n+1}),
// where z1 = (1 + x + sqrt(x^2+6x+1)) / 2.
signed main()
{
	// Seeds the search inside Cipolla. The returned root is the smaller of the
	// two, so the output does not depend on the seed.
	srand(time(NULL));
	// iiv[i] = i^{-1} mod Mod via the linear-time recurrence.
	iiv[1]=1;
	for(ri int i=2;i<=N;i++) iiv[i]=1ll*(Mod-Mod/i)*iiv[Mod%i]%Mod;
	// r[i][1] / r[i][0]: primitive 2^i-th root of unity and its inverse
	// (Mod = 119 * 2^23 + 1, primitive root 3).
	r[23][1]=ksc(3,119), r[23][0]=ksc(ksc(3,Mod-2),119);
	for(ri int i=22;~i;i--) r[i][0]=1ll*r[i+1][0]*r[i+1][0]%Mod, r[i][1]=1ll*r[i+1][1]*r[i+1][1]%Mod;
	// From here on K is the number of coefficients (degrees 0..k).
	n=read(), K=read(); K++;
	vector<int> a,F;
	a.resize(K), F.resize(K);
	// a = 1 + 6x + x^2. Assumes k >= 1 so that a has at least two entries.
	a[0]=1, a[1]=6;
	if(K>2) a[2]=1;
	// a <- sqrt(1 + 6x + x^2) mod x^K.
	Getsqrt(K,a,a);
	// G <- 1 / sqrt(1 + 6x + x^2) mod x^K. GetInv is given a zero-filled output.
	vector<int> G; G=a;
	vector<int> H=G;
	for(ri int i=0;i<K;i++) G[i]=0;
	GetInv(K,G,H);
	// a <- 1 + x + sqrt(1 + 6x + x^2) = 2 * z1.
	a[0]++, a[0]%=Mod, a[1]++, a[1]%=Mod;
	int inv2=(Mod+1)/2;
	// w = a[0] / 2 is the constant term of z1. Dividing a by 2w gives z1 / w,
	// whose constant term is 1 as required by GetLn.
	int w=1ll*a[0]*inv2%Mod;
	for(ri int i=0,inv=ksc(w,Mod-2);i<K;i++) a[i]=1ll*a[i]*inv%Mod*inv2%Mod;
	// F <- (z1 / w)^{n+1} = exp((n+1) * ln(z1 / w)); the exponent is reduced mod Mod.
	GetLn(K,F,a);
	for(ri int i=0;i<K;i++) F[i]=1ll*F[i]*((n+1)%Mod)%Mod;
	GetExp(K,F,F);
	// Multiply back by w^{n+1} (exponent reduced mod Mod-1 by Fermat's little theorem).
	for(ri int i=0,gg=ksc(w,(n+1)%(Mod-1));i<K;i++) F[i]=1ll*F[i]*gg%Mod;
	// F <- z1^{n+1} / sqrt(1 + 6x + x^2).
	NTT(K,K,F,G);
	// The closed form only holds mod x^{n+1}, so coefficients above degree n are printed as 0.
	for(ri int i=1;i<K;i++) printf("%d ",(i<=n)?F[i]:0);
	puts("");
	return 0;
}

\(\text{Solution 3}:\) 倍增

\(f_{i,j}\) 表示前 \(i\) 个球分成 \(j\) 组的方案数。与 \(\text{Solution 2}\) 不同,我们考虑另外一种转移:每次转移可以看作把两排球拼在一起,根据拼接处是否恰好把一个包含两个球的组分开来分类讨论,有转移:

\[f_{x+y,i}=\sum\limits_{j=0}^{i}f_{x,j}\times f_{y,i-j}+\sum\limits_{j=0}^{i-1}f_{x-1,j}\times f_{y-1,i-j-1} \]

\(F_{i}(x)\) 表示序列 \(f_{i}\)\(\text{OGF}\),有:

\[F_{i+j}(x)=F_{i}(x)F_{j}(x)+xF_{i-1}(x)F_{j-1}(x) \]

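发现求 \(F_{i+j}\) 还需要 \(F_{i-1},F_{j-1}\);而由同一式子,\(F_{i+j-1}=F_{i}F_{j-1}+xF_{i-1}F_{j-2}\)\(F_{i+j-2}=F_{i-1}F_{j-1}+xF_{i-2}F_{j-2}\) 只用到下标为 \(i,i-1,i-2\)\(j,j-1,j-2\) 的多项式。故对每个状态维护三元组 \((F_{m},F_{m-1},F_{m-2})\) 即可闭合:对 \(n\) 做二进制拆分,倍增维护 \(F_{2^{t}},F_{2^{t}-1},F_{2^{t}-2}\) 并合并到答案上。每次合并只需保留前 \(k+1\) 项,共 \(O(\log n)\) 次合并,总时间复杂度 \(O(k\log k\log n)\)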

\(\text{Code 3}:\)

#include <bits/stdc++.h>
#pragma GCC optimize(3)
//#define int long long
#define ri register
#define mk make_pair
#define fi first
#define se second
#define pb push_back
#define eb emplace_back
#define is insert
#define es erase
#define vi vector<int>
#define vpi vector<pair<int,int>>
using namespace std;
// N bounds the size of the global tables; Mod = 119 * 2^23 + 1 is the
// NTT-friendly prime (3 is used as its primitive root when building r[][]).
const int N=135010;
const int Mod=998244353;
// Reads one decimal integer from stdin. Non-digit characters before the first
// digit are skipped, and a '-' among them makes the result negative. Digits are
// consumed up to and including the first non-digit character.
// Returns 0 if end of input is reached before any digit. The previous version
// looped forever in that case, because EOF was only compared against '0'..'9'.
inline int read()
{
	int s=0, w=1;
	int ch=getchar(); // int rather than char so EOF stays distinguishable
	while(ch!=EOF&&(ch<'0'||ch>'9'))
	{
		if(ch=='-') w=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9')
	{
		s=s*10+(ch-'0');
		ch=getchar();
	}
	return s*w;
}
int n,K; // number of balls, and the largest group count to print

// Doubling state used by Merge: for some index m, a[0], a[1], a[2] hold the
// truncated coefficient vectors of F_m, F_{m-1}, F_{m-2}.
struct DP
{
	vector<int> a[3];
};
DP f,g;          // f: accumulated result state; g: state for the current power of two

int rev[N];      // bit-reversal permutation (filled by Get_Rev)
int r[24][2];    // r[i][1], r[i][0]: primitive 2^i-th root of unity and its inverse
// Modular exponentiation by repeated squaring: returns x^p mod Mod.
// p must be non-negative.
inline int ksc(int x,int p)
{
	long long base=x, acc=1;
	while(p)
	{
		if(p&1) acc=acc*base%Mod;
		base=base*base%Mod;
		p>>=1;
	}
	return (int)acc;
}
// Fills rev[0..T) with the bit-reversal permutation for a power-of-two T.
// Each entry is derived from rev[i>>1], so rev[0] must be 0 (it is never
// changed from its zero initial value).
inline void Get_Rev(int T)
{
	const int top=T>>1;
	for(int i=0;i<T;i++)
	{
		rev[i]=rev[i>>1]>>1;
		if(i&1) rev[i]|=top;
	}
}
inline void DFT(int T,vector<int> &s,int type)
{
	for(ri int i=0;i<T;i++) if(rev[i]<i) swap(s[i],s[rev[i]]);
	for(ri int i=2,cnt=1;i<=T;i<<=1,cnt++)
	{
		int wn=r[cnt][type];
		for(ri int j=0,mid=(i>>1);j<T;j+=i)
		{
			for(ri int k=0,w=1;k<mid;k++,w=1ll*w*wn%Mod)
			{
				int x=s[j+k], y=1ll*w*s[j+mid+k]%Mod;
				s[j+k]=(x+y)%Mod;
				s[j+mid+k]=x-y;
				if(s[j+mid+k]<0) s[j+mid+k]+=Mod;
			}
		}
	}
	if(!type) for(ri int i=0,inv=ksc(T,Mod-2);i<T;i++) s[i]=1ll*s[i]*inv%Mod;
}
// Doubling step for Solution 3. If x is the state for index i and y the state
// for index j (each holding F_m, F_{m-1}, F_{m-2} in a[0..2]), x becomes the
// state for index i+j using
//   F_{i+j}   = F_i     F_j     + x F_{i-1} F_{j-1}
//   F_{i+j-1} = F_i     F_{j-1} + x F_{i-1} F_{j-2}
//   F_{i+j-2} = F_{i-1} F_{j-1} + x F_{i-2} F_{j-2}
// Results are truncated to degree K. y is taken by value, so Merge(g,g) is safe.
inline void Merge(DP &x,DP y)
{
	const int full=(int)x.a[0].size()+(int)y.a[0].size()-1;
	const int keep=min(full,K);    // highest degree kept in the result
	const int need=min(full,K+K);  // the transform length must exceed this
	int T=1;
	while(T<=need) T<<=1;
	Get_Rev(T);
	// Forward transforms of all three components of both operands.
	vector<int> X[3],Y[3];
	for(int t=0;t<3;t++)
	{
		X[t]=x.a[t];
		X[t].resize(T);
		DFT(T,X[t],1);
		Y[t]=y.a[t];
		Y[t].resize(T);
		DFT(T,Y[t],1);
	}
	// P[t] = X[L[t]] * Y[R[t]], i.e. X0Y0, X0Y1, X1Y1, X1Y2, X2Y2.
	const int L[5]={0,0,1,1,2};
	const int R[5]={0,1,1,2,2};
	vector<int> P[5];
	for(int t=0;t<5;t++)
	{
		P[t].resize(T);
		for(int i=0;i<T;i++) P[t][i]=1ll*X[L[t]][i]*Y[R[t]][i]%Mod;
		DFT(T,P[t],0);
	}
	// New a[t] = P[t] + x * P[t+2]; multiplying by x shifts by one index.
	for(int t=0;t<3;t++)
	{
		x.a[t].resize(keep+1);
		x.a[t][0]=P[t][0];
		for(int i=1;i<=keep;i++) x.a[t][i]=(P[t][i]+P[t+2][i-1])%Mod;
	}
}
// Program entry for Solution 3. Reads n and K and prints f_1..f_K, the
// coefficients of F_n(x), obtained by binary exponentiation over the index
// using the doubling rule implemented in Merge.
signed main()
{
	// r[i][1] / r[i][0]: primitive 2^i-th root of unity and its inverse
	// (Mod = 119 * 2^23 + 1, primitive root 3).
	r[23][1]=ksc(3,119), r[23][0]=ksc(ksc(3,Mod-2),119);
	for(int i=22;i>=0;i--) r[i][0]=1ll*r[i+1][0]*r[i+1][0]%Mod, r[i][1]=1ll*r[i+1][1]*r[i+1][1]%Mod;
	n=read(), K=read();
	// f starts at index 0 (a[0] = 1). g starts at index 1 with a[0] = 1 + x,
	// a[1] = 1 and a[2] = 0.
	for(int i=0;i<3;i++) f.a[i].resize(1), g.a[i].resize(2);
	f.a[0][0]=1;
	g.a[0][0]=g.a[1][0]=g.a[0][1]=1;
	// f accumulates index n bit by bit while g runs through the indices 2^t.
	// The squaring of g after the highest bit is skipped because its result
	// would never be used (the original loop performed it anyway).
	for(int p=n;p;)
	{
		if(p&1) Merge(f,g);
		p>>=1;
		if(p) Merge(g,g);
	}
	// F_n has degree n, so the counts for more than n groups are 0.
	for(int i=1;i<=K;i++) printf("%d ",(i<=n)?f.a[0][i]:0);
	puts("");
	return 0;
}

后记:三种做法在 \(\text{CF}\) 上的用时分别为 \(124,1060,3541\)(单位 \(\text{ms}\)),可见实际运行效率还是有一定的差距。

posted @ 2021-04-27 14:29  zkdxl  阅读(69)  评论(0)  编辑  收藏  举报