CF755G PolandBall and Many Other Balls【矩阵快速幂+多项式乘法】
CF755G PolandBall and Many Other Balls
(看了洛谷上的题解都做得好复杂2333
考虑如果没有选
k
k
k 组的限制那么有转移
F
n
=
F
n
−
2
+
F
n
−
1
+
F
n
−
1
F_n=F_{n-2}+F_{n-1}+F_{n-1}
Fn=Fn−2+Fn−1+Fn−1
把
F
n
F_n
Fn 看成多项式
F
n
(
x
)
F_n(x)
Fn(x) ,其中
[
x
i
]
F
n
(
x
)
[x^i]F_n(x)
[xi]Fn(x) 表示选了
i
i
i 组的方案数。
那么就转移变成了
F
n
(
x
)
=
F
n
−
2
(
x
)
x
+
F
n
−
1
(
x
)
x
+
F
n
−
1
F_n(x)=F_{n-2}(x)x+F_{n-1}(x)x+F_{n-1}
Fn(x)=Fn−2(x)x+Fn−1(x)x+Fn−1
最后矩阵快速幂即可。
时间复杂度 O ( n log 2 n ) O(n\log^2n) O(nlog2n) 。
#include <bits/stdc++.h>
#define N 70000
using namespace std;
typedef long long ll;
typedef vector<ll> vec;
const ll mod=998244353;
int now_limit;
int rev[N];
ll wq[17][N],fac[N],inv[N];
ll ksm(ll x,ll y){
ll res=1; while(y){ if(y&1)res=res*x%mod; x=x*x%mod; y>>=1; }
return res;
}
void init_NTT(int limit){
if(limit==now_limit) return;
now_limit=limit; int l=0; while((1<<l)<limit)l++;
for(int i=0;i<limit;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
if(wq[0][0]==0){
limit=1<<16;
for(int mid=1,l=0;mid<limit;mid<<=1,l++){
wq[l][0]=1,wq[l][1]=ksm(3,(mod-1)/(mid<<1));
for(int k=2;k<mid;k++) wq[l][k]=wq[l][k-1]*wq[l][1]%mod;
}
}
}
void NTT(vec &a,int limit,int flag){
init_NTT(limit);
while(a.size()<limit) a.push_back(0);
for(int i=0;i<limit;i++) if(i<rev[i])swap(a[i],a[rev[i]]);
ll x,y;
for(int mid=1,l=0;mid<limit;mid<<=1,l++)
for(int j=0;j<limit;j+=(mid<<1))
for(int k=0;k<mid;k++){
x=a[j+k],y=a[j+k+mid]*wq[l][k]%mod;
a[j+k]=(x+y)%mod,a[j+k+mid]=(x-y+mod)%mod;
}
if(flag==-1){
x=ksm(limit,mod-2);
for(int i=0;i<limit;i++) a[i]=a[i]*x%mod;
reverse(&a[1],&a[limit]);
}
}
int n,k;
vec operator *(vec x,vec y){
int len=min((int)x.size(),k)+min((int)y.size(),k)-1;
int limit=1; while(limit<len)limit<<=1;
NTT(x,limit,1),NTT(y,limit,1);
for(int i=0;i<limit;i++) x[i]=x[i]*y[i]%mod;
NTT(x,limit,-1);
if(limit>k) for(int i=k;i<limit;i++) x[i]=0;
return x;
}
void operator +=(vec &x,vec y){
int len=max(x.size(),y.size());
x.resize(len),y.resize(len);
for(int i=0;i<len;i++) x[i]=(x[i]+y[i])%mod;
}
struct mtx{
vec a[3][3];
friend mtx operator *(const mtx &x,const mtx &y){
mtx z;
for(int i=1;i<3;i++)
for(int j=1;j<3;j++)
for(int k=1;k<3;k++)
z.a[i][j]+=(x.a[i][k]*y.a[k][j]);
return z;
}
};
int main(){
// freopen("test.in","r",stdin);
// freopen("test.out","w",stdout);
cin>>n>>k; k++;
mtx g,res,f;
res.a[1][1].push_back(1),res.a[2][1].push_back(0);
res.a[1][2].push_back(0),res.a[2][2].push_back(1);
g.a[1][1].push_back(0);
g.a[2][1].push_back(1);
g.a[1][2].push_back(0); g.a[1][2].push_back(1);
g.a[2][2].push_back(1); g.a[2][2].push_back(1);
f.a[1][1].push_back(1);
f.a[1][2].push_back(1); f.a[1][2].push_back(1);
while(n){
if(n&1) res=res*g;
g=g*g;
n>>=1;
}
res=f*res; int m=res.a[1][1].size();
for(int i=1;i<k;i++){
if(m<=i)printf("0 ");
else printf("%lld ",res.a[1][1][i]);
}
}