「NOI2017」泳池

「NOI2017」泳池

可以发现每一列出现指定高度的安全位置的概率是可以预处理的,设概率为 \(w_i\)

由于连续面积不超过 \(k\),所以我们可以优先预处理出连续 \(k\) 个以内高度 \(>0\) 的方案数

要求连续 \(k\) 个高度 \(>0\) 的,我们还可以进一步降维,求连续 \(k/2\)\(>1\),然后继续递归。。。

每次转移都是一段从上一层转移下来然后,紧接着一段是空的

最终迭代得到 \(A[i]\) 表示连续 \(i\) 个高度 \(>0\) 的列的概率,复杂度 \(O(k^2\ln k)\) (以下)

可以得到最终的 \(dp\) 转移方程

\(dp_i=\sum_1^kdp_{i-j}A_{j-1}w_0\) 表示我们强制这一位是空的

可以发现这是一个简单的线性递推关系,如果使用矩阵优化,可以做到 \(O(k^3)logn\)

正解我们要用到 常系数线性齐次递推

表示我特别喜欢用vector实现多项式问题。。。

很多大佬的板子都经过卡常,可读性并不是太高。。。

#include<cstdio>
#include<cctype>
#include<queue>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;

#define reg register
typedef long long ll;
#define rep(i,a,b) for(int i=a,i##end=b;i<=i##end;++i)
#define drep(i,a,b) for(int i=a,i##end=b;i>=i##end;--i)

#define pb push_back
template <class T> inline void cmin(T &a,T b){ ((a>b)&&(a=b)); }
template <class T> inline void cmax(T &a,T b){ ((a<b)&&(a=b)); }

char IO;
template<class T=int> T rd(){
	T s=0;
	int f=0;
	while(!isdigit(IO=getchar())) if(IO=='-') f=1;
	do s=(s<<1)+(s<<3)+(IO^'0');
	while(isdigit(IO=getchar()));
	return f?-s:s;
}

const int N=(1<<10)+10,P=998244353;


ll qpow(ll x,ll k){
	ll res=1;
	for(;k;k>>=1,x=x*x%P) if(k&1) res=res*x%P;
	return res;
}

namespace Polynomial{
	int rev[N<<2];
	void NTT(int n,vector <int> &a,int f){
		rep(i,0,n-1) if(i<rev[i]) swap(a[i],a[rev[i]]);
		for(int i=1;i<n;i<<=1){
			ll w=qpow(f==1?3:(P+1)/3,(P-1)/i/2);
			for(int l=0;l<n;l+=i*2){
				ll e=1;
				for(int j=l;j<l+i;++j,e=e*w%P) {
					ll t=a[j+i]*e%P;
					a[j+i]=(a[j]-t+P)%P;
					a[j]=(a[j]+t)%P;
				}
			}
		}
		if(f==-1){
			ll t=qpow(n,P-2);
			rep(i,0,n-1) a[i]=a[i]*t%P;
		}
	}

	int PreMake(int n){
		int R=1,cc=-1;
		while(R<=n) R<<=1,cc++;
		rep(i,1,R-1) rev[i]=(rev[i>>1]>>1)|((i&1)<<cc);
		return R;
	}


	vector <int> operator * (vector <int> a,vector <int> b){
		int n=a.size()+b.size()-1,R=PreMake(n);
		a.resize(R),b.resize(R);
		NTT(R,a,1),NTT(R,b,1);
		rep(i,0,R-1) a[i]=1ll*a[i]*b[i]%P;
		NTT(R,a,-1);
		return a;
	}

	vector <int> Inv(vector <int> a,int n){
		if(n==1){ vector <int> tmp; tmp.pb(qpow(a[0],P-2)); return tmp; }
		a.resize(n);
		vector <int> b=Inv(a,(n+1)>>1);
		int R=PreMake(n*2);
		a.resize(R),b.resize(R);
		NTT(R,a,1),NTT(R,b,1);
		rep(i,0,R-1) b[i]=(P+2-1ll*a[i]*b[i]%P)*b[i]%P;
		NTT(R,b,-1);
		b.resize(n);
		return b;
	}

	vector <int> operator / (vector <int> a,vector <int> b) {
		int n=a.size(),m=b.size();
		reverse(a.begin(),a.end()),reverse(b.begin(),b.end());
		b.resize(n-m+1);
		b=Inv(b,n-m+1), a=a*b;
		a.resize(n-m+1);
		reverse(a.begin(),a.end());
		return a;
	}

	vector <int> operator - (vector <int> a,vector <int> b){
		int sz=max(a.size(),b.size());
		a.resize(sz),b.resize(sz);
		rep(i,0,a.size()-1) a[i]=(a[i]-b[i]+P)%P;
		return a;
	}

	vector <int> operator % (vector <int> a,vector <int> b){
		a=a-a/b*b;
		a.resize(b.size()-1);
		return a;
	}

	void Show(vector <int> a){ for(int i:a) printf("%d ",i); puts(""); }
	int st[N];
}
using namespace Polynomial;


int n,k;
ll p;
ll w[N],s[N],t[N],h[N],dp[N];
ll Calc(int k) {
	memset(s,0,sizeof s),memset(t,0,sizeof t);
	drep(i,k,1) {
		if(i==k) {
			t[1]=w[k];
			continue;
		}
		memset(s,0,sizeof s),memset(h,0,sizeof h);
		s[0]=1;
		rep(j,0,k/i) {
			rep(d,j+1,k/i) {
				if(d==j+1) s[d]=(s[d]+s[j]*w[i])%P;
				else s[d]=(s[d]+s[j]*w[i]%P*t[d-j-1])%P;
			}
		}
		rep(j,0,k/i) {
			h[j]=(h[j]+s[j])%P;
			rep(d,j+1,k/i) h[d]=(h[d]+s[j]*t[d-j])%P;
		}
		memcpy(t,h,sizeof t);
	}
	memcpy(s,t,sizeof t);
	dp[0]=s[0]=1;
	rep(i,0,k) s[i]=s[i]*w[0]%P;
	drep(i,++k,1) s[i]=s[i-1];

	rep(i,1,N-1) {
		dp[i]=0;
		rep(j,max(i-k,0),i-1) dp[i]=(dp[i]+dp[j]*s[i-j])%P;
	}
	if(n+1<N) return dp[n+1]*qpow(w[0],P-2)%P;

	int p=n+1;
	vector <int> Mod,x,res;
	drep(i,k,1) Mod.pb(s[i]); 
	rep(i,0,Mod.size()-1) Mod[i]=(P-Mod[i])%P; 
	Mod.pb(1);
	x.resize(k),x[1]=1,res.resize(k),res[0]=1;
	for(;p;p>>=1,x=x*x%Mod) if(p&1) res=res*x%Mod;
	ll ans=0;
	rep(i,0,k-1) ans=(ans+1ll*res[i]*dp[i])%P;
	ans=ans*qpow(w[0],P-2)%P;
	return ans;
}

int main(){
	n=rd(),k=rd();
	p=rd(),p=p*qpow(rd(),P-2)%P;
	w[0]=(P+1-p)%P;
	rep(i,1,k) w[i]=w[i-1]*p%P;
	if(n==1) {
		printf("%lld\n",w[k]);
		return 0;
	}
	printf("%lld\n",(Calc(k)-Calc(k-1)+P)%P);
}


posted @ 2020-04-17 19:02  chasedeath  阅读(137)  评论(0编辑  收藏  举报