BMCH
放板子好啊。
#include <bits/stdc++.h>
using namespace std;
// 64-bit integer used for intermediate products before reduction modulo M.
using i64 = long long;
// Shorthand for an integer vector (sequences and recurrence coefficients).
using vi = vector<int>;
// Modulus used by all modular arithmetic in this file.
constexpr int M = 998244353;
namespace {
int add(int x,int y){
return (x+=y)>=M?x-M:x;
}
int sub(int x,int y){
return (x-=y)<0?x+M:x;
}
int mul(int x,int y){
return (i64)x*y%M;
}
int fp(int x,int y){
int ret=1;
for (; y; y>>=1,x=(i64)x*x%M)
if (y&1) ret=(i64)ret*x%M;
return ret;
}
int fraction(int x,int y){
return mul(x,fp(y,M-2));
}
}
// Element-wise modular addition of b into a.
// a is first extended with zeros when b is longer, so a.size() ends up as
// the larger of the two sizes; entries of a beyond b.size() are unchanged.
void operator +=(vi &a,const vi &b){
    const size_t len=b.size();
    if (a.size()<len) a.resize(len);
    for (size_t idx=0; idx!=len; ++idx)
        a[idx]=add(a[idx],b[idx]);
}
ostream& operator <<(ostream& os,const vi &_){
os<<"{";
for (auto i:_) os<<i<<",";
return os<<"}";
}
istream& operator >>(istream& is,vi &_){
size_t len;
is>>len;
_.resize(len);
for (auto &i:_) is>>i;
return is;
}
// Berlekamp-Massey style search for a linear recurrence of the sequence v
// modulo M.
//
// Returns coefficients c such that, for the positions checked,
//   v[i] == sum over j of c[j] * v[i-j-1]   (mod M).
// An empty result is returned when every term of v is zero.
vi BM(const vi &v){
    vi cur;      // recurrence that fits every term processed so far
    vi prev;     // recurrence held just before the last lengthening step
    vi failPos;  // positions at which a lengthening step was recorded
    vi failDel;  // discrepancy observed at each of those positions
    const int total=static_cast<int>(v.size());
    for (int pos=0; pos<total; ++pos){
        // Value that the current recurrence predicts for v[pos].
        int predicted=0;
        const int curLen=static_cast<int>(cur.size());
        for (int t=0; t<curLen; ++t)
            predicted=add(predicted,mul(v[pos-t-1],cur[t]));
        const int delta=sub(v[pos],predicted);
        if (delta==0) continue;  // prediction already matches
        if (failPos.empty()){
            // First mismatch: start with pos+1 zero coefficients.
            cur.resize(pos+1);
        }
        else{
            // Build a correction from the stored recurrence: shifted so it
            // lines up with its recorded failure, and scaled so it
            // contributes exactly delta at position pos.
            const int scale=fraction(delta,failDel.back());
            const int shift=pos-failPos.back()-1;
            vi correction(shift+1+prev.size());
            correction[shift]=scale;
            for (size_t t=0; t<prev.size(); ++t)
                correction[shift+1+t]=sub(0,mul(scale,prev[t]));
            // Only a correction that lengthens the recurrence replaces the
            // stored (prev, failure position, discrepancy) triple.
            const bool lengthens=correction.size()>cur.size();
            if (lengthens) prev=cur;
            cur+=correction;
            if (!lengthens) continue;
        }
        failPos.push_back(pos);
        failDel.push_back(delta);
    }
    return cur;
}
void mul(int *a,int *b,const int *c,int k){
static int *tmp=new int[k*2-1];
memset(tmp,0,sizeof(*tmp)*(k*2-1));
for (int i=0; i<k; ++i)
for (int j=0; j<k; ++j)
tmp[i+j]=(tmp[i+j]+(i64)a[i]*b[j])%M;
for (int i=k*2-2; i>=k; --i)
for (int j=1; j<=k; ++j)
tmp[i-j]=(tmp[i-j]+(i64)tmp[i]*c[j-1])%M;
memcpy(a,tmp,sizeof(*tmp)*k);
}
int CH(const vi &&a,vi h,int n){
int k=a.size();
//cerr<<k<<endl;
assert(h.size()>=k);
//cerr<<h<<endl;
h.resize(2*k);
for (int i=k; i<2*k; ++i){
h[i]=0;
for (int j=0; j<k; ++j)
h[i]=add(h[i],mul(h[i-j-1],a[j]));
}
if (n<2*k) return h[n];
//cerr<<"TCH"<<endl;
int *res=new int[k](),*x=new int[k]();
res[0]=1;
x[1]=1;
for (int y=n-k+1; y; y>>=1,mul(x,x,a.data(),k))
if (y&1) mul(res,x,a.data(),k);
int ans=0;
for (int i=0; i<k; ++i) ans=add(ans,mul(h[k-1+i],res[i]));
return ans;
}
int calc(const vi &val,int n){
//cerr<<"calc"<<calc<<endl;
return CH(BM(val),val,n);
}
// Base size constant; the DP table below has D*20 columns.
constexpr int D=2005;
using u32=unsigned int;
// DP table filled in main: row z is updated from row z-1 with a column
// offset equal to the item currently being added.
u32 f[21][D*20];
// Not used by main, which declares a local vector named d that shadows it.
int d[D];
// Inputs read from standard input in main.
int n;
int k;
// main iterates over items 1..2*LIM and caps the column index at LIM*k.
constexpr int LIM=2000;
// Twice the modulus: Addt subtracts it once whenever a sum reaches it.
constexpr u32 MM=2u*static_cast<u32>(M);
// Adds y to x in place and subtracts MM once if the sum reached MM.
// When both values are below MM beforehand, x stays below MM afterwards.
void Addt(u32 &x,u32 y){
    x+=y;
    if (x>=MM) x-=MM;
}
int main(){
cin>>n>>k;
if (k==1){
cout<<1<<endl;
return 0;
}
vi d;
f[1][0]=1;
f[0][0]=1;
d.push_back(1);
for (int i=1; i<=2*LIM; ++i){
for (int z=k; z; --z){
int dlim=min(i*(i+1)/2,LIM*k);
u32 *h=&f[z][dlim],*g=&f[z-1][dlim-i];
for (int j=dlim; j>=i; --j)
Addt(*(h--),*(g--));
}
if (i%2==0)
d.push_back(f[k][i/2*k]%M);
}
cout<<calc(d,n);
}