【BZOJ4944】[NOI2017]泳池(线性常系数齐次递推,动态规划)
【BZOJ4944】[NOI2017]泳池(线性常系数齐次递推,动态规划)
首先恰好为\(k\)很不好算,变为至少或者至多计算然后考虑容斥。
如果是至少的话,我们依然很难处理最大面积这个东西。所以考虑答案至多为\(k\)的概率,再减去至多为\(k-1\)的概率就是最终的答案。
发现要求的东西必须贴着底边,所以对于每一列而言我们需要考虑的就是选定区间的最低的那个不安全的格子的行号,再乘上底边的长度。
所以考虑设\(f[n]\)表示底边长度为\(n\)的答案,即确定底边长度为\(n\)时,面积小于等于\(k\)的答案。
那么我们有这样一个转移:
翻译一下,首先我们枚举一下上一个在底边断开的位置,即上一次某一列的第一行就存在一个障碍,那么就可以知道这一次的底边长度是\(i\),前面符合条件的是\(f[n-i-1]\),然后枚举这一次的高度是多少,确定高度之后乘上在这个高度上至少存在一个障碍的概率即\(dp[i][j]\)值。
注意一下我们这里的模型,是每次考虑一段第一行不为障碍的东西,而分割的地方我们强制存在障碍,即我们考虑完了这一段之后强制在末尾放了一个障碍,所以答案是\(f[n+1]\),而末尾那个障碍是不需要放的,所以我们实际要求的东西是\(\frac{f[n+1]}{1-p}\),其中\(p\)是不是障碍的概率。
那么求出\(dp\)值之后,每次的系数就唯一确定了,转成线性常系数递推。
考虑怎么求解\(dp\)值。
显然这个\(dp\)值要做的就是找到一个位置使得其障碍高度恰好为\(j\),然后其他位置都不小于\(j\)。
那么我们考虑枚举最靠左侧的那个障碍的位置\(j\),那么它是障碍,所以概率是\(1-p\),而它下边的都不是障碍,所以概率是\(p^{j-1}\),,那么先枚举这个最靠左的位置是\(l\),那么我们就可以得到转移:
即考虑其左右的位置,因为这个位置是最靠左的,所以左侧的最低的障碍一定都比这个位置的障碍要高,所以限制条件是\(k\gt j\),而右边无所谓,只要不比这里矮就行了,所以枚举的是\(k\ge j\)。
那么这个\(dp\)方程可以很容易的使用后缀和优化得到。
那么单次转移\(O(k)\),状态总数\(O(klogk)\),这是因为\(i*(j-1)\le k\),所以状态是调和级数级别的。所以这部分的\(dp\)的复杂度是\(O(k^2logk)\)。
接下来维护好后缀和,那么\(f\)数组的求解就是一个线性常系数齐次递推式。
暴力\(O(nk)\),矩乘\(O(k^3logn)\),这就\(90\)分了。
这个东西可以参考这里。
因为特征多项式的系数是\(k\),所以多项式取模和多项式乘法可以暴力,这部分的复杂度就是\(O(k^2)\),加上快速幂的一个\(log\),所以这部分的复杂度就是\(O(k^2logk)\)。
暴力多项式乘法不用说,暴力多项式取模就是模拟长除法的过程,显然是一个\(O(k^2)\)的过程。
综上,本题的复杂度就是\(O(k^2logk)\)。
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<vector>
using namespace std;
#define ll long long
#define MOD 998244353
#define MAX 1010
inline int read()
{
int x=0;bool t=false;char ch=getchar();
while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
if(ch=='-')t=true,ch=getchar();
while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
return t?-x:x;
}
int fpow(int a,int b){int s=1;while(b){if(b&1)s=1ll*s*a%MOD;a=1ll*a*a%MOD;b>>=1;}return s;}
int f[MAX][MAX],g[MAX][MAX],pw[MAX],p,np,n,K;
int M[MAX],pre[MAX<<1],A[MAX<<1],S[MAX<<1],B[MAX<<1],tmp[MAX<<1];
int Val(int a,int b){return 1ll*a*fpow(b,MOD-2)%MOD;}
void Mod(int *S,int len,int K)
{
for(int i=len;i>=K;--i)
{
int t=S[i];
for(int j=0;j<=K;++j)
S[i-j]=(S[i-j]+MOD-1ll*B[K-j]*t%MOD)%MOD;
}
}
int Solve(int K)
{
memset(f,0,sizeof(f));memset(g,0,sizeof(g));
memset(M,0,sizeof(M));memset(pre,0,sizeof(pre));
memset(S,0,sizeof(S));memset(A,0,sizeof(A));
memset(B,0,sizeof(B));memset(tmp,0,sizeof(tmp));
for(int i=1;i<=K+2;++i)f[0][i]=g[0][i]=1;
for(int i=1;i<=K;++i)
for(int j=K/i+1;j;--j)
{
for(int l=1;l<=i;++l)f[i][j]=(f[i][j]+1ll*g[l-1][j+1]*g[i-l][j]%MOD*pw[j-1]%MOD*np)%MOD;
g[i][j]=(g[i][j+1]+f[i][j])%MOD;
}
for(int i=0;i<=K;++i)M[i+1]=1ll*np*g[i][2]%MOD;
pre[0]=1;
for(int i=1;i<=K;++i)
for(int j=1;j<=i;++j)
pre[i]=(pre[i]+1ll*pre[i-j]*M[j])%MOD;
/*
for(int i=K+1;i<=n+1;++i)
for(int j=1;j<=K+1;++j)
pre[i]=(pre[i]+1ll*pre[i-j]*M[j])%MOD;
*/
if(n+1<=K)return 1ll*pre[n+1]*fpow(np,MOD-2)%MOD;
K+=1;
for(int i=0;i<K;++i)B[i]=(MOD-M[K-i])%MOD;B[K]=1;
A[1]=1;S[0]=1;int b=n+1;
while(b)
{
if(b&1)
{
for(int i=0;i<=K;++i)
for(int j=0;j<=K;++j)
tmp[i+j]=(tmp[i+j]+1ll*A[i]*S[j])%MOD;
for(int i=0;i<=K+K;++i)S[i]=tmp[i],tmp[i]=0;
Mod(S,K+K,K);
}
for(int i=0;i<=K;++i)
for(int j=0;j<=K;++j)
tmp[i+j]=(tmp[i+j]+1ll*A[i]*A[j])%MOD;
for(int i=0;i<=K+K;++i)A[i]=tmp[i],tmp[i]=0;
Mod(A,K+K,K);
b>>=1;
}
int ret=0;
for(int i=0;i<K;++i)ret=(ret+1ll*S[i]*pre[i])%MOD;
return 1ll*ret*fpow(np,MOD-2)%MOD;
}
int main()
{
n=read();K=read();p=read();p=1ll*p*fpow(read(),MOD-2)%MOD;np=(1+MOD-p)%MOD;
pw[0]=1;for(int i=1;i<=K;++i)pw[i]=1ll*pw[i-1]*p%MOD;
printf("%d\n",(Solve(K)-Solve(K-1)+MOD)%MOD);
return 0;
}