Codeforces 1349D Slime and Biscuits

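设 \(E_i\) 为"所有饼干第一次集中到同一个人手上且此人为 \(i\)(即游戏在 \(i\) 结束)"时,游戏结束时间对答案的贡献,即 \(\mathbb{E}\big[T\cdot[\text{游戏在 } i \text{ 结束}]\big]\);\(P_i\) 为游戏在 \(i\) 结束的概率。再令 \(E'_i\) 为只有饼干全部集中到 \(i\) 手上时游戏才结束(集中到其他人手上时游戏继续)的情况下,所有饼干第一次集中到 \(i\) 手上的期望时间;\(C\) 为将所有饼干从一个人手上全部转移到另一个指定的不同人手上的期望时间。固定 \(i\),记 \(x\) 为 \(i\) 手上的饼干数。一步操作后 \(x\) 不变、加一、减一的概率分别为 \(p_1,p_2,p_3\)。\(F_x\) 为 \(x\) 第一次变成 \(x+1\) 的期望步数。那么有: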

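\[
\begin{aligned}
Ans &= \sum_{i=1}^n E_i \\
E'_i &= E_i + \sum_{j \neq i} \left(E_j + P_j C\right) \;\Rightarrow\; E_i = E'_i - \sum_{j \neq i} \left(E_j + P_j C\right) \\
Ans = \sum_{i=1}^n E_i &= \sum_{i=1}^n E'_i - (n-1)\,Ans - (n-1)\,C \qquad \left(\textstyle\sum_{j} P_j = 1\right) \\
n\,Ans &= \sum_{i=1}^n E'_i - (n-1)\,C \\
p_1 &= \frac{S-x}{S}\cdot\frac{n-2}{n-1}, \qquad p_2 = \frac{S-x}{S}\cdot\frac{1}{n-1}, \qquad p_3 = \frac{x}{S} \\
F_0 &= \frac{1}{p_2} = n-1 \\
F_x &= \left(\frac{p_2+p_3}{p_2}-1\right)\left(\frac{1}{1-p_1}+F_{x-1}\right)+\frac{1}{1-p_1} = \frac{1+p_3F_{x-1}}{p_2} \qquad (x \ge 1) \\
E'_i &= \sum_{k=A_i}^{S-1} F_k, \qquad C = \sum_{k=0}^{S-1} F_k
\end{aligned}
\]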

/*program by mangoyang*/
#pragma GCC optimize("Ofast", "inline")
#include<bits/stdc++.h>
// Helper macros: inf, Max and Min are not referenced anywhere in this solution.
// Max/Min expand textually, so each argument may be evaluated more than once.
#define inf ((ll) 3e18)
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
// 64-bit integer alias; only the inf macro refers to it.
typedef long long ll;
using namespace std;
/*
 * Reads the next (optionally negative) decimal integer from stdin into x.
 * Any non-digit characters before the first digit are skipped; a '-' among them
 * makes the result negative. The character right after the last digit is consumed.
 * If end of input is reached before any digit, x is left as 0 instead of looping
 * forever on EOF.
 */
template <class T>
inline void read(T &x){
    int ch = getchar();
    bool negative = false;
    x = 0;
    // Skip to the first digit; stopping at EOF keeps truncated input from hanging.
    while(ch != EOF && !isdigit(ch)){
        if(ch == '-') negative = true;
        ch = getchar();
    }
    // isdigit(EOF) is false, so this loop also stops at end of input.
    for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
    if(negative) x = -x;
}
// Array capacity and the prime modulus all answers are reported under.
// f and sum are indexed by biscuit count (0..S) and a by player (1..n), so S and n
// are assumed to stay below N — TODO confirm against the problem limits.
constexpr int N = 500005;
constexpr int mod = 998244353;
int n;        // number of players
int S;        // total number of biscuits (sum of a[1..n])
int a[N];     // a[i]: biscuits initially held by player i (1-indexed)
int f[N];     // f[x]: F_x, expected steps for x to first become x + 1
int sum[N];   // sum[x] = f[x] + f[x + 1] + ... + f[S - 1]
// In-place modular addition: x becomes (x + y) reduced by at most one subtraction of mod.
// Assumes 0 <= x < mod and 0 <= y <= mod so a single conditional subtraction is enough.
inline void up(int &x, int y){
	x += y;
	if(x >= mod) x -= mod;
}
// Binary exponentiation: returns base^e modulo mod (e = 0 gives 1).
// Called with e = mod - 2 to get modular inverses (998244353 is prime, so Fermat's
// little theorem applies). e must be non-negative.
inline int Pow(int base, int e){
	long long result = 1;
	long long square = base;
	while(e != 0){
		if(e & 1) result = result * square % mod;
		square = square * square % mod;
		e >>= 1;
	}
	return static_cast<int>(result);
}
/*
 * Reads n and a_1..a_n from stdin and prints the expected game length modulo 998244353.
 *
 * With x = biscuits held by a fixed target player, F_x (stored in f[x]) follows
 *   F_0 = n - 1,
 *   F_x = ((p2 + p3) / p2 - 1) * (1 / (1 - p1) + F_{x-1}) + 1 / (1 - p1).
 * E'_i = F_{a_i} + ... + F_{S-1} and C = F_0 + ... + F_{S-1}, and the answer is
 *   Ans = (sum_i E'_i - (n - 1) * C) / n.
 */
int main(){
	read(n);
	for(int i = 1; i <= n; i++) read(a[i]), S += a[i];
	// F_0 = 1 / p2 at x = 0, where p2 = 1 / (n - 1).
	f[0] = n - 1;
	const int InvS = Pow(S, mod - 2);
	const int C2 = Pow(n - 1, mod - 2);            // 1 / (n - 1)
	const int C1 = 1ll * (n - 2) * C2 % mod;       // (n - 2) / (n - 1), reusing C2
	for(int i = 1; i < S; i++){
		// i plays the role of x in the write-up.
		const int outside = 1ll * (S - i) * InvS % mod;   // (S - x) / S
		const int p1 = 1ll * outside * C1 % mod;          // x stays the same
		const int p2 = 1ll * outside * C2 % mod;          // x -> x + 1
		const int p3 = 1ll * i * InvS % mod;              // x -> x - 1
		// 1 / (1 - p1) appears twice in the recurrence; take the inverse once.
		const int invMove = Pow(mod + 1 - p1, mod - 2);
		// tmp = (p2 + p3) / p2 - 1
		int tmp = 1ll * (p2 + p3) % mod * Pow(p2, mod - 2) % mod;
		up(tmp, mod - 1);
		// F_x = tmp * (1 / (1 - p1) + F_{x-1}) + 1 / (1 - p1)
		int acc = invMove;
		up(acc, f[i - 1]);
		f[i] = 1ll * tmp * acc % mod;
		up(f[i], invMove);
	}
	// Suffix sums: sum[k] = F_k + ... + F_{S-1}; sum[S] stays 0 (global zero-init).
	for(int i = S - 1; i >= 0; i--)
		sum[i] = (sum[i + 1] + f[i]) % mod;
	// n * Ans = sum_i E'_i - (n - 1) * C with E'_i = sum[a_i] and C = sum[0].
	int res = 0;
	for(int i = 1; i <= n; i++) up(res, sum[a[i]]);
	up(res, mod - 1ll * (n - 1) * sum[0] % mod);
	res = 1ll * res * Pow(n, mod - 2) % mod;
	cout << res << '\n';
	return 0;
}


posted @ 2020-05-13 22:25  Joyemang33  阅读(301)  评论(1)  编辑  收藏  举报