BMCH — Berlekamp–Massey (BM) + Cayley–Hamilton (CH) template for linear recurrences

放板子好啊。

#include <bits/stdc++.h>

using namespace std;
using i64 = long long;          // wide enough for the product of two residues
using vi = std::vector<int>;    // sequence / polynomial of residues mod M
const int M = 998244353;        // prime modulus; fraction() relies on primality (Fermat inverse)

// Small modular-arithmetic helpers. add/sub expect operands already reduced
// into [0, M) and return a result in [0, M).
namespace {
	// (x + y) mod M.
	int add(int x, int y) {
		const int s = x + y;
		if (s >= M) return s - M;
		return s;
	}

	// (x - y) mod M.
	int sub(int x, int y) {
		const int d = x - y;
		if (d < 0) return d + M;
		return d;
	}

	// (x * y) mod M, using a 64-bit intermediate product.
	int mul(int x, int y) {
		return static_cast<i64>(x) * y % M;
	}

	// x^y mod M by square-and-multiply (y >= 0).
	int fp(int x, int y) {
		i64 acc = 1;
		i64 base = x;
		while (y) {
			if (y & 1) acc = acc * base % M;
			y >>= 1;
			base = base * base % M;
		}
		return static_cast<int>(acc);
	}

	// x / y mod M, i.e. x * y^(M-2) (requires y != 0 mod M).
	int fraction(int x, int y) {
		const int inv = fp(y, M - 2);
		return mul(x, inv);
	}
}
// Element-wise a += b (mod M). If b is longer, a is first zero-extended to
// b's length; any extra tail of a is left untouched.
void operator +=(vi &a,const vi &b){
	a.resize(std::max(a.size(), b.size()));
	for (std::size_t idx = 0; idx < b.size(); ++idx)
		a[idx] = add(a[idx], b[idx]);
}
ostream& operator <<(ostream& os,const vi &_){
	os<<"{";
	for (auto i:_) os<<i<<",";
	return os<<"}";
}
istream& operator >>(istream& is,vi &_){
	size_t len;
	is>>len;
	_.resize(len);
	for (auto &i:_) is>>i;
	return is;
}
// Berlekamp-Massey over Z/MZ.
// Returns c[0..L-1], the coefficients of a shortest linear recurrence with
//   v[i] = sum_{j=0}^{L-1} c[j] * v[i-j-1]   for every L <= i < v.size().
// The result is empty when v contains no nonzero term. Entries of v are
// expected to be residues in [0, M) (add/sub assume reduced operands).
vi BM(const vi &v){
	vi now;             // current recurrence
	vi past;            // recurrence in force before the last length change
	int lastFail = -1;  // index of the discrepancy that caused that change
	int lastDel = 0;    // discrepancy value observed at lastFail
	const int len = static_cast<int>(v.size());
	for (int i = 0; i < len; ++i){
		// Discrepancy between v[i] and the value predicted by `now`.
		// now.size() <= i holds here, so v[i-j-1] never goes below index 0.
		int del = 0;
		for (int j = 0; j < static_cast<int>(now.size()); ++j)
			del = add(del, mul(v[i-j-1], now[j]));
		del = sub(v[i], del);
		if (!del) continue;
		if (lastFail < 0){
			// First nonzero term: v[0..i-1] are all 0, so the recurrence needs
			// length i+1 (all-zero coefficients).
			now.resize(i+1);
		} else {
			// Correction r: ro at position shift-1 followed by -ro*past,
			// scaled so that adding it to `now` cancels the discrepancy at i.
			const int ro = fraction(del, lastDel);
			const int shift = i - lastFail;  // >= 1
			vi r(shift + past.size());
			r[shift-1] = ro;
			for (size_t t = 0; t < past.size(); ++t)
				r[shift+t] = sub(0, mul(ro, past[t]));
			if (r.size() <= now.size()){
				// Length unchanged: keep the previous reference point.
				now += r;
				continue;
			}
			past = now;
			now += r;
		}
		lastFail = i;
		lastDel = del;
	}
	return now;
}
void mul(int *a,int *b,const int *c,int k){
	static int *tmp=new int[k*2-1];
	memset(tmp,0,sizeof(*tmp)*(k*2-1));
	for (int i=0; i<k; ++i)
		for (int j=0; j<k; ++j)
			tmp[i+j]=(tmp[i+j]+(i64)a[i]*b[j])%M;
	for (int i=k*2-2; i>=k; --i)
		for (int j=1; j<=k; ++j)
			tmp[i-j]=(tmp[i-j]+(i64)tmp[i]*c[j-1])%M;

	memcpy(a,tmp,sizeof(*tmp)*k);
}
int CH(const vi &&a,vi h,int n){
	int k=a.size();
	//cerr<<k<<endl;
	assert(h.size()>=k);
	//cerr<<h<<endl;
	h.resize(2*k);
	for (int i=k; i<2*k; ++i){
		h[i]=0;
		for (int j=0; j<k; ++j)
			h[i]=add(h[i],mul(h[i-j-1],a[j]));
	}
	if (n<2*k) return h[n];
	//cerr<<"TCH"<<endl;
	int *res=new int[k](),*x=new int[k]();
	res[0]=1;
	x[1]=1;
	for (int y=n-k+1; y; y>>=1,mul(x,x,a.data(),k))
		if (y&1) mul(res,x,a.data(),k);
	int ans=0;
	for (int i=0; i<k; ++i) ans=add(ans,mul(h[k-1+i],res[i]));
	return ans;
}
int calc(const vi &val,int n){
	//cerr<<"calc"<<calc<<endl;
	return CH(BM(val),val,n);
}

constexpr int D = 2005;
using u32 = unsigned int;
// DP table filled by main(); entries are kept in [0, 2M) by Addt.
// NOTE(review): 21 rows and D*20 columns only suffice for k <= 20.
u32 f[21][D*20];
// NOTE(review): not referenced by the visible code (main() uses its own local sequence).
int d[D];
int n, k;                  // read in main()
constexpr int LIM = 2000;  // main() tabulates sequence terms 0..LIM
constexpr u32 MM = M + M;  // lazy-reduction bound 2M

// x = x + y, reduced into [0, 2M). Expects x, y < 2M, so x + y < 4M, which
// still fits in 32 bits (4M = 3992977412 < 2^32) and one subtraction suffices.
void Addt(u32 &x, u32 y){
    x += y;
    if (x >= MM)
        x -= MM;
}
int main(){
    cin>>n>>k;
    if (k==1){
        cout<<1<<endl;
        return 0;
    }
    vi d;
    f[1][0]=1;
    f[0][0]=1;
 	d.push_back(1);
    for (int i=1; i<=2*LIM; ++i){
        for (int z=k; z; --z){
        	int dlim=min(i*(i+1)/2,LIM*k);
            u32 *h=&f[z][dlim],*g=&f[z-1][dlim-i];
            for (int j=dlim; j>=i; --j)
                Addt(*(h--),*(g--));
        }
        if (i%2==0)
        d.push_back(f[k][i/2*k]%M);
    }
    cout<<calc(d,n);
}
posted @ 2019-04-10 22:03  Yuhuger  阅读(523)  评论(0)  编辑  收藏  举报