题解-CF755G PolandBall and Many Other Balls
题面
给定 \(n\) 和 \(m\)。有一排 \(n\) 个球,求对于每个 \(1\le k\le m\),选出 \(k\) 个球或相邻的球不能重复的方案数。
数据范围:\(1\le n\le 10^9\),\(1\le m<2^{15}\)。
路标
这题是老经典题了,前人的描述也足够充分了。
但是蒟蒻尝试了这题的 \(3\) 种做法并记在笔记中后还是忍不住去挣咕值分享给大家。
这是的三种做法对比图(从下到上依次是倍增(\(\Theta(n\log^2 n)\))、组合容斥(\(\Theta(n\log n)\))和特征方程(\(\Theta(n\log n)\))):
代码中的 \(n\) 和 \(m\) 可能和题面中不是同个东西 /wul
。
组合容斥
考虑枚举两个球的组的数量,这应该是最容易想到的式子了:
可是这个东西很难推,于是重定义它的组合意义:
先从前 \(k\) 个选任意个,然后再从剩下的(也可以是前 \(k\) 个!)当中选 \(k\) 个的方案数。
抓住容斥切口:前后没有重复。
设 \(p(i)\) 为重复 \(i\) 个剩下不重复的方案数。
设 \(q(i)\) 为钦定重复 \(i\) 个剩下随意的方案数。
还可以用“从前 \(k\) 个中先把两次重复的 \(i\) 个选了,然后对于前面剩下的每个随便选不选,后面的选 \(k-i\) 个”得出 \(q(i)\)。
所以上式等于
\(n\) 太大了,不能直接卷,得化下降幂:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef double db;
#define x first
#define y second
#define bg begin()
#define ed end()
#define pb push_back
#define mp make_pair
#define sz(a) int((a).size())
#define R(i,n) for(int i(0);i<(n);++i)
#define L(i,n) for(int i((n)-1);i>=0;--i)
const int iinf=0x3f3f3f3f;
const ll linf=0x3f3f3f3f3f3f3f3f;
//Data
const int N=1<<15,pN=1<<16,mN=N;
int n,m,a[pN],b[pN];
//Math
const int mod=998244353;
void fmod(int &x){x+=x>>31&mod;}
int Pow(int a,int x,int res=1){
for(;x;x>>=1,a=1ll*a*a%mod) (x&1)&&(res=1ll*res*a%mod);
return res;
}
int fac[mN],ifac[mN],dfac[mN],idfac[mN],pw=1;
//Poly
const int G=3,iG=Pow(3,mod-2);
int pn,rev[pN];
#define cle(p) fill((p)+n,(p)+pn,0)
int up2(int n){return 1<<int(ceil(log2(n)));}
void revn(){R(i,pn) rev[i]=(rev[i>>1]>>1)|((i&1)*(pn>>1));}
void ntt(int* p,bool t){
R(i,pn)if(i<rev[i]) swap(p[i],p[rev[i]]);
for(int mid=1;mid<pn;mid<<=1)
for(int wn=Pow(t?iG:G,(mod-1)/(mid<<1)),i=0;i<pn;i+=mid<<1)
for(int j=i,w=1,x,y;j<(mid|i);j++,w=1ll*w*wn%mod)
x=p[j],y=1ll*p[mid|j]*w%mod,fmod(p[j]+=y-mod),fmod(p[mid|j]=x-y);
if(t){int in=Pow(pn,mod-2); R(i,pn) p[i]=1ll*in*p[i]%mod;}
}
//Main
int main(){
ios::sync_with_stdio(0);
cin.tie(0),cout.tie(0);
cin>>m>>n,++n,pn=up2(n<<1),revn();
fac[0]=dfac[0]=1;
R(i,n-1) fac[i+1]=(1ll+i)*fac[i]%mod;
R(i,min(n-1,m))dfac[i+1]=1ll*(m-i)*dfac[i]%mod;
ifac[n-1]=Pow(fac[n-1],mod-2);
idfac[min(n-1,m)]=Pow(dfac[min(n-1,m)],mod-2);
L(i,n-1) ifac[i]=(1ll+i)*ifac[i+1]%mod;
L(i,min(n-1,m)) idfac[i]=1ll*(m-i)*idfac[i+1]%mod;
R(i,n) b[i]=1ll*ifac[i]*ifac[i]%mod*pw%mod,fmod(pw=(pw<<1)-mod),
a[i]=1ll*ifac[i]*idfac[i]%mod,(i&1)&&(fmod(a[i]=-a[i]),true);
ntt(a,false),ntt(b,false);
R(i,pn) a[i]=1ll*a[i]*b[i]%mod; ntt(a,true),cle(a);
R(i,n) a[i]=1ll*a[i]*fac[i]%mod*dfac[i]%mod;
R(i,n)if(i) cout<<a[i]<<' '; cout<<'\n';
return 0;
}
倍增
设 \(f_{n,k}\) 表示前 \(n\) 个球分成 \(k\) 组的方案数,简单的 dp
式:
设 \(F_n(x)=\sum_{k=0}^n f_{n,k}x^k\),可通过上式推出:
这个东西在这里还没用,但是后面特征方程做法的时候会用到。
考虑将两排球合并,分类讨论中间的断口是否切开了一组相邻的球:
然后就可以倍增卷积了,要维护 \(F_{2^k},F_{2^k-1},F_{2^k-2}\):
然后设当前的 \(F_{m}\) 要与倍增数组合并,其实只需要维护 \(F_{m}\) 和 \(F_{m-1}\) 即可。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef double db;
#define x first
#define y second
#define bg begin()
#define ed end()
#define pb push_back
#define mp make_pair
#define sz(a) int((a).size())
#define R(i,n) for(int i(0);i<(n);++i)
#define L(i,n) for(int i((n)-1);i>=0;--i)
const int iinf=0x3f3f3f3f;
const ll linf=0x3f3f3f3f3f3f3f3f;
//Data
const int N=1<<15,pN=1<<16,D=31;
int m,n,f[3][pN],ns[2][pN],tmp[3][pN];
//Math
const int mod=998244353;
void fmod(int &x){x+=x>>31&mod;}
int Pow(int a,int x,int res=1){
for(;x;x>>=1,a=1ll*a*a%mod) (x&1)&&(res=1ll*res*a%mod);
return res;
}
//Poly
const int G=3,iG=Pow(G,mod-2);
int pn,rev[pN];
void revn(){R(i,pn) rev[i]=(rev[i>>1]>>1)|((i&1)*(pn>>1));}
void ntt(int* p,bool t){
R(i,pn)if(i<rev[i]) swap(p[i],p[rev[i]]);
for(int mid=1;mid<pn;mid<<=1)
for(int wn=Pow(t?iG:G,(mod-1)/(mid<<1)),i=0;i<pn;i+=mid<<1)
for(int j=i,w=1,x,y;j<(mid|i);j++,w=1ll*w*wn%mod)
x=p[j],y=1ll*w*p[mid|j]%mod,fmod(p[j]=x+y-mod),fmod(p[mid|j]=x-y);
if(t){int in=Pow(pn,mod-2); R(i,pn) p[i]=1ll*p[i]*in%mod;}
}
//Main
int main(){
ios::sync_with_stdio(0);
cin.tie(0),cout.tie(0);
cin>>m>>n,++n,pn=1;
while(pn<(n<<1)) pn<<=1; revn();
R(d,D){
if(d==0) f[0][0]=f[0][1]=f[1][0]=1,ns[0][0]=1;
else {
R(t,3)R(i,pn) tmp[t][i]=(i<n?f[t][i]:0);
R(t,3) ntt(tmp[t],false);
R(i,pn){
f[0][i]=1ll*tmp[0][i]*tmp[0][i]%mod;
f[1][i]=1ll*tmp[1][i]*tmp[0][i]%mod;
f[2][i]=1ll*tmp[1][i]*tmp[1][i]%mod;
tmp[0][i]=1ll*tmp[1][i]*tmp[1][i]%mod;
tmp[1][i]=1ll*tmp[1][i]*tmp[2][i]%mod;
tmp[2][i]=1ll*tmp[2][i]*tmp[2][i]%mod;
}
R(t,3) ntt(tmp[t],true),ntt(f[t],true);
R(t,3)R(i,n-1) fmod(f[t][i+1]+=tmp[t][i]-mod);
R(t,3) fill(f[t]+n,f[t]+pn,0);
}
if(m>>d&1^1) continue;
// cout<<"@DEBUG d="<<d<<'\n';
// R(i,pn) cout<<f[0][i]<<' ';cout<<'\n';
R(t,3)R(i,pn) tmp[t][i]=(i<n?f[t][i]:0);
R(t,3) ntt(tmp[t],false);
R(t,2) ntt(ns[t],false);
R(i,pn){
int t[2]={ns[0][i],ns[1][i]};
ns[0][i]=1ll*t[0]*tmp[0][i]%mod;
ns[1][i]=1ll*t[0]*tmp[1][i]%mod;
tmp[0][i]=1ll*tmp[1][i]*t[1]%mod;
tmp[1][i]=1ll*tmp[2][i]*t[1]%mod;
}
R(t,2) ntt(ns[t],true),ntt(tmp[t],true);
R(t,2)R(i,n-1) fmod(ns[t][i+1]+=tmp[t][i]-mod);
R(t,2) fill(ns[t]+n,ns[t]+pn,0);
}
R(i,n)if(i) cout<<ns[0][i]<<' '; cout<<'\n';
return 0;
}
特征方程
参考倍增卷积做法的前半部分:
这东西可以多项式特征方程:设有多项式 \(a,b\) 满足:
已知 \((a+b)=(x+1),ab=-x\),所以 \(a,b\) 是特征方程的两根:
设 \(a=\frac{x+1+\sqrt{x^2+6x+1}}{2}\),\(b=\frac{x+1-\sqrt{x^2+6x+1}}{2}\)。
由于 \(\Delta>0\),那个根号肯定 \(>0\),\(a\neq b\)。
加点文字防止看得眼花 /qq
。
然后带上多项式全家桶就可以做了。
当时推到 \(\color{orange}{(1)}\) 那里的时候其实已经可以做了,然后蒟蒻点开题解发现和鰰的不一样,蒟蒻不是很懂为什么
然后 qq
群中的 Elegia
队长还有鰰现场回复:因为这个式子常数项为 \(0\)!
学数学学傻了,那个多项式 \(x^2+6x+1\) 开根后常数项肯定是 \(1\),其实是可以和前面的 \(1\) 抵消的。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef double db;
#define x first
#define y second
#define bg begin()
#define ed end()
#define pb push_back
#define mp make_pair
#define sz(a) int((a).size())
#define R(i,n) for(int i(0);i<(n);++i)
#define L(i,n) for(int i((n)-1);i>=0;--i)
const int iinf=0x3f3f3f3f;
const ll linf=0x3f3f3f3f3f3f3f3f;
//Data
const int N=1<<15,pN=1<<16,mN=pN;
int n,m,pm,a[pN],b[pN],c[pN];
//Math
const int mod=998244353;
void fmod(int &x){x+=x>>31&mod;}
int Pow(int a,int x,int res=1){
for(;x;x>>=1,a=1ll*a*a%mod) (x&1)&&(res=1ll*res*a%mod);
return res;
}
int miv[mN];
void math_init(){
miv[1]=1;
R(i,mN)if(i>=2) fmod(miv[i]=-1ll*mod/i*miv[mod%i]%mod);
}
//Poly
const int G=3,iG=Pow(G,mod-2);
int pn,rev[pN],red[pN],blue[pN],yel[pN];
#define cle(p) fill((p)+n,(p)+pn,0)
int up2(int n){return 1<<int(ceil(log2(n)));}
void revn(){R(i,pn) rev[i]=(rev[i>>1]>>1)|((i&1)*(pn>>1));}
void ntt(int* p,bool t){
R(i,pn)if(i<rev[i]) swap(p[i],p[rev[i]]);
for(int mid=1;mid<pn;mid<<=1)
for(int wn=Pow(t?iG:G,(mod-1)/(mid<<1)),i=0;i<pn;i+=mid<<1)
for(int j=i,w=1,x,y;j<(mid|i);j++,w=1ll*w*wn%mod)
x=p[j],y=1ll*w*p[mid|j]%mod,fmod(p[j]=x+y-mod),fmod(p[mid|j]=x-y);
if(t){int in=Pow(pn,mod-2); R(i,pn) p[i]=1ll*p[i]*in%mod;}
}
void deint(int* p,bool t){
if(t){L(i,pn)if(i) p[i]=1ll*p[i-1]*miv[i]%mod; p[0]=0;}
else {R(i,pn-1) p[i]=1ll*p[i+1]*(i+1)%mod; p[pn-1]=0;}
}
void mul(int* p,int* q,int n){
pn=n<<1,revn(),cle(p),cle(q),ntt(p,false),ntt(q,false);
R(i,pn) p[i]=1ll*p[i]*q[i]%mod; ntt(p,true),cle(q);
}
void inv(int* p,int* q,int n){
if(n==1) return void(q[0]=Pow(p[0],mod-2));
inv(p,q,n>>1),pn=n<<1,revn(),copy(p,p+n,red);cle(red),cle(q),ntt(q,false),ntt(red,false);
R(i,pn) fmod(q[i]=(-1ll*red[i]*q[i]%mod+2)*q[i]%mod); ntt(q,true),cle(q);
}
void ln(int* p,int* q,int n){
inv(p,q,n),pn=n<<1,revn(),copy(p,p+n,red),cle(red),cle(q);
deint(red,false),ntt(q,false),ntt(red,false);
R(i,pn) q[i]=1ll*q[i]*red[i]%mod; ntt(q,true),deint(q,true),cle(q);
}
void exp(int* p,int* q,int n){
if(n==1) return void(q[0]=1);
exp(p,q,n>>1),ln(q,blue,n),pn=n<<1,revn(),fmod(blue[0]-=1);
R(i,n) fmod(blue[i]=p[i]-blue[i]); ntt(q,false),ntt(blue,false);
R(i,pn) q[i]=1ll*q[i]*blue[i]%mod; ntt(q,true),cle(q),cle(blue);
}
void sqrt(int* p,int* q,int n){
if(n==1) return void(q[0]=1);
sqrt(p,q,n>>1),inv(q,blue,n),pn=n<<1,revn(),copy(p,p+n,yel);
cle(yel),cle(blue),ntt(blue,false),ntt(yel,false);
R(i,pn) blue[i]=1ll*blue[i]*yel[i]%mod; ntt(blue,true),cle(blue);
R(i,n) fmod(q[i]+=blue[i]-mod),q[i]=1ll*(mod+1)/2*q[i]%mod; cle(q);
}
void pow(int* p,int* q,int x,int n){
ln(p,yel,n); R(i,n) yel[i]=1ll*yel[i]*x%mod; exp(yel,q,n);
}
//Main
int main(){
ios::sync_with_stdio(0);
cin.tie(0),cout.tie(0);
cin>>m>>n,++n,math_init(),pm=up2(n);
a[0]=1,a[1]=6,a[2]=1,sqrt(a,b,pm),inv(b,a,pm);
fmod(b[0]+=1-mod),fmod(b[1]+=1-mod);
R(i,pm) b[i]=499122177ll*b[i]%mod;
pow(b,c,(m+1)%mod,pm),mul(a,c,pm);
R(i,n)if(i) cout<<(i<=m?a[i]:0)<<' '; cout<<'\n';
return 0;
}
祝大家学习愉快!