洛谷P6395 千年食谱颂
很脑洞的 \(dp\) 题
首先我们令 \(dp_i\) 表示从 \(i-1\) 转移到 \(i\) 的期望时间
即从在场的 \(i-1\) 个已被吃转移到在场的 \(i\) 个已被吃
考虑转移:
- 在上一时刻和该时刻的时间间隔内撤下的不包含已选的
这种情况的转移很简单,令 \(p\) 为 \(\frac{a}{b}\),则:
\[dp_i+=(1-p)^{i-1} \left ( \frac{i-1}{n}(dp_i+1)+\frac{n-i+1}{n} \right )
\]
- 在上一时刻和该时刻的时间间隔内撤下的包含已选的
按 \(dp\) 转移意转移即可,首先枚举选中了几个,然后考虑在该时刻选的是已选的还是未选的,即:
\[dp_i+=\sum_{k=1}^{i-1}p^{k}(1-p)^{i-1-k}\binom{i-1}{k}\left ( 1+dp_i+\frac{i-1-k}{n}\sum_{j=i-1-k+1}^{i-1}dp_j+\frac{n-i+1+k}{n}\sum_{j=i-1-k+2}^{i-1}dp_j \right )
\]
第二个转移方程内的求和可以用前缀和优化,所以复杂度为常数较小的 \(O(n^2)\)
Code
#include <bits/stdc++.h>
#define re register
#define int long long
#define ll long long
// #define lls long long
#define pir make_pair
#define fr first
#define sc second
#define db double
using namespace std;
const int mol=998244353;
const int maxn=1e7+10;
const int INF=1e9+10;
inline int qpow(int a,int b) { int ans=1; while(b) { if(b&1) (ans*=a)%=mol; (a*=a)%=mol; b>>=1; } return ans; }
inline int read() {
int s=0,w=1; char ch=getchar();
while(ch<'0'||ch>'9') { if(ch=='-') w=-1; ch=getchar(); }
while(ch>='0'&&ch<='9') { s=s*10+ch-'0'; ch=getchar(); }
return s*w;
}
int n,a,b,invn,ye[maxn],no[maxn],dp[maxn],fac[maxn],inv[maxn],sum[maxn];
inline int getsum(int l,int r) { if(l>r) return 0; return (sum[r+1]-sum[l]+mol)%mol; }
inline int C(int n,int m) { return fac[n]*inv[m]%mol*inv[n-m]%mol; }
signed main(void) {
n=read(); a=read(); b=read(); invn=qpow(n,mol-2); ye[1]=a*qpow(b,mol-2)%mol; no[1]=(1ll-ye[1]+mol)%mol;
ye[0]=no[0]=1; for(re int i=2;i<=n;i++) ye[i]=ye[i-1]*ye[1]%mol,no[i]=no[i-1]*no[1]%mol;
fac[0]=1; for(re int i=1;i<=n;i++) fac[i]=fac[i-1]*i%mol;
inv[n]=qpow(fac[n],mol-2); for(re int i=n;i>=1;i--) inv[i-1]=inv[i]*i%mol;
for(re int i=1;i<=n;i++) {
int s=0,xi=0;
for(re int k=1;k<=i-1;k++) {
(s+=ye[k]*no[i-1-k]%mol*C(i-1,k)%mol*((i-1-k)*invn%mol*getsum(i-1-k+1,i-1)%mol+(n-i+1+k)*invn%mol*getsum(i-1-k+2,i-1)%mol+1))%=mol;
(xi+=ye[k]*no[i-1-k]%mol*C(i-1,k)%mol)%=mol;
}
(xi+=no[i-1]*(i-1)%mol*invn%mol)%=mol;
(s+=no[i-1])%=mol;
xi=(1-xi+mol)%mol;
dp[i]=s*qpow(xi,mol-2)%mol;
sum[i+1]=(sum[i]+dp[i])%mol;
}
printf("%lld\n",sum[n+1]);
}