【JZOJ7336】拆分
【JZOJ7336】拆分
by AmanoKumiko
Description
定义\(n\)的拆分为满足\(\sum_{i=1}^ka_i=n\)的\(a\)序列且\(n\)中的元素均为自然数,其中\(a\)有序
给定\(n,m,k\),求\(\sum \prod_{i=1}^ka_i^{m}\)对\(998244353\)取模的结果
Input
一行三个数n,m,k,详细意义如题。
Output
一个数,表示答案
Sample Input
3 1 2
Sample Output
4
Data Constraint
\(0< n\le 10^7,0< k\le 10^3,0\le m \le 10^3\)
Solution
考虑\(m=0\)时怎么做
它的生成函数即为\(F(x)=(\sum _{i=0}^nx^i)^k\),答案即为\([x^n]\),很明显是个组合数
再思考\(m>0\)时
此时的\(F(x)=(\sum_{i=0}^ni^mx^i)^k\),由于若存在\(a_i\)为\(0\),答案就为\(0\)
所以我们左移一下,改为\(F(x)=(\sum_{i=0}^n(i+1)^mx^i)^k\),答案改为求\([x^{n-k}]\)
想办法求封闭形式
对于一个二次函数,发现它实际上是个二阶等差
那么对于一个\(m\)次函数,它应该是一个\(m\)阶等差
例如,\(m=3\)时:
\(1,8,27,64,125,216\)
\(1,7,19,37,61,91\)
\(1,6,12,18,24,30\)
\(1,5,6,6,6,6\)
此时可以发现,再差分一次可以让后面都变为\(0\)
即\(1,4,1,0,0,0\)
也就是说,原式经过\(m+1\)次差分可以变为只有\(m\)项的多项式\(A\)
再思考一次差分怎么表示,实际上就是\((1-x)\)
那么答案应为\([x^{n-k}](\frac{1}{(1-x)^{m+1}}·A)^k\)
两部分分开算,左边是个组合数,右边可以多项式快速幂
Code
#include<bits/stdc++.h>
using namespace std;
#define F(i,a,b) for(int i=a;i<=b;i++)
#define Fd(i,a,b) for(int i=a;i>=b;i--)
#define LL long long
#define mo 998244353
#define N 20000010
#define L 4000010
int n,m,k,rev[L];
LL fac[N],ans;
LL mi(LL x,LL y){
if(y==1)return x;
return y%2?x*mi(x*x%mo,y/2)%mo:mi(x*x%mo,y/2);
}
struct poly{
int len;
LL val[L];
void NTT(int x){
F(i,0,len-1)if(i<rev[i])swap(val[i],val[rev[i]]);
for(int mid=1;mid<len;mid*=2){
LL gn=mi(3,(mo-1)/(mid*2));
if(x==-1)gn=mi(gn,mo-2);
for(int i=0;i<len;i+=mid*2){
LL g=1;
F(j,0,mid-1){
LL le=val[i+j],ri=g*val[i+j+mid]%mo;
val[i+j]=(le+ri)%mo;val[i+j+mid]=(le-ri+mo)%mo;
(g*=gn)%=mo;
}
}
}
if(x==-1){
LL inv=mi(len,mo-2);
F(i,0,len-1)(val[i]*=inv)%=mo;
}
}
void DFT(){NTT(1);}
void IDFT(){NTT(-1);}
}f;
void trans(int x){
int bit=log2(x);
F(i,0,x-1)rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
}
void Pow(poly&x,int y,int lim){
int l=1;
while(l<lim+1)l<<=1;trans(l);
x.len=l;
x.DFT();
F(i,0,x.len-1)x.val[i]=mi(x.val[i],y);
x.IDFT();
}
LL C(int x,int y){
return fac[x]*mi(fac[y]*fac[x-y]%mo,mo-2)%mo;
}
int main(){
freopen("split.in","r",stdin);
freopen("split.out","w",stdout);
scanf("%d%d%d",&n,&m,&k);
fac[0]=1;
F(i,1,N-10)fac[i]=fac[i-1]*i%mo;
if(m==0){
printf("%lld",C(n+k-1,n));
return 0;
}
F(i,0,m-1)f.val[i]=mi(i+1,m);
F(i,1,m+1) Fd(j,m-1,1)f.val[j]=(f.val[j]-f.val[j-1]+mo)%mo;
Pow(f,k,m*k);
F(i,0,m*k)(ans+=f.val[i]*C(m*k+k+n-k-i-1,n-k-i)%mo)%=mo;
printf("%lld",ans);
return 0;
}