洛谷P6395 千年食谱颂

很脑洞的 \(dp\)

首先我们令 \(dp_i\) 表示从 \(i-1\) 转移到 \(i\) 的期望时间

即从在场的 \(i-1\) 个已被吃转移到在场的 \(i\) 个已被吃

考虑转移:

  • 在上一时刻和该时刻的时间间隔内撤下的不包含已选的

这种情况的转移很简单,令 \(p\)\(\frac{a}{b}\),则:

\[dp_i+=(1-p)^{i-1} \left ( \frac{i-1}{n}(dp_i+1)+\frac{n-i+1}{n} \right ) \]

  • 在上一时刻和该时刻的时间间隔内撤下的包含已选的

\(dp\) 转移意转移即可,首先枚举选中了几个,然后考虑在该时刻选的是已选的还是未选的,即:

\[dp_i+=\sum_{k=1}^{i-1}p^{k}(1-p)^{i-1-k}\binom{i-1}{k}\left ( 1+dp_i+\frac{i-1-k}{n}\sum_{j=i-1-k+1}^{i-1}dp_j+\frac{n-i+1+k}{n}\sum_{j=i-1-k+2}^{i-1}dp_j \right ) \]

第二个转移方程内的求和可以用前缀和优化,所以复杂度为常数较小的 \(O(n^2)\)

Code
#include <bits/stdc++.h>
#define re register
#define int long long
#define ll long long
// #define lls long long
#define pir make_pair
#define fr first 
#define sc second
#define db double
using namespace std;
const int mol=998244353;
const int maxn=1e7+10;
const int INF=1e9+10;
inline int qpow(int a,int b) { int ans=1; while(b) { if(b&1) (ans*=a)%=mol; (a*=a)%=mol; b>>=1; } return ans; }
inline int read() {
    int s=0,w=1; char ch=getchar();
    while(ch<'0'||ch>'9') { if(ch=='-') w=-1; ch=getchar(); }
    while(ch>='0'&&ch<='9') { s=s*10+ch-'0'; ch=getchar(); }
    return s*w;
}

int n,a,b,invn,ye[maxn],no[maxn],dp[maxn],fac[maxn],inv[maxn],sum[maxn];
inline int getsum(int l,int r) { if(l>r) return 0; return (sum[r+1]-sum[l]+mol)%mol; }
inline int C(int n,int m) { return fac[n]*inv[m]%mol*inv[n-m]%mol; }
signed main(void) {
	n=read(); a=read(); b=read(); invn=qpow(n,mol-2); ye[1]=a*qpow(b,mol-2)%mol; no[1]=(1ll-ye[1]+mol)%mol;
 	ye[0]=no[0]=1; for(re int i=2;i<=n;i++) ye[i]=ye[i-1]*ye[1]%mol,no[i]=no[i-1]*no[1]%mol;
	fac[0]=1; for(re int i=1;i<=n;i++) fac[i]=fac[i-1]*i%mol;
	inv[n]=qpow(fac[n],mol-2); for(re int i=n;i>=1;i--) inv[i-1]=inv[i]*i%mol;
	for(re int i=1;i<=n;i++) {
		int s=0,xi=0;
		for(re int k=1;k<=i-1;k++) {
			(s+=ye[k]*no[i-1-k]%mol*C(i-1,k)%mol*((i-1-k)*invn%mol*getsum(i-1-k+1,i-1)%mol+(n-i+1+k)*invn%mol*getsum(i-1-k+2,i-1)%mol+1))%=mol;
			(xi+=ye[k]*no[i-1-k]%mol*C(i-1,k)%mol)%=mol;
		}
		(xi+=no[i-1]*(i-1)%mol*invn%mol)%=mol;
		(s+=no[i-1])%=mol;
		xi=(1-xi+mol)%mol;
		dp[i]=s*qpow(xi,mol-2)%mol;
		sum[i+1]=(sum[i]+dp[i])%mol;
	}
	printf("%lld\n",sum[n+1]);
}
posted @ 2021-11-01 19:45  zJx-Lm  阅读(39)  评论(0编辑  收藏  举报