Codeforces 1349D Slime and Biscuits
设 \(E_i\) 为所有饼干第一次都收到 \(i\) 手上,且游戏在 \(i\) 结束的期望时间,\(P_i\) 为这种情况的概率,再令 \(E'_i\) 为假设游戏只有在所有饼干都收到 \(i\) 手上时才结束的情况下的期望时间, \(C\) 为将所有饼干从一个人手上转移到另一个不同人手上的期望时间。下文中 \(S\) 为饼干总数,\(x\) 为目标当前持有的饼干数,\(F_x\) 为目标持有的饼干数从 \(x\) 变为 \(x+1\) 的期望时间,那么有:
\[Ans = \sum_{i=1}^n E_{i} \\
E_i = E'_i - \sum_{j=1}^n[i \neq j] (E_{j}+P_jC) \\
Ans = \sum_{i=1}^n E_i =\sum_{i=1}^n E'_i -(n-1)\sum_{i=1}^n E_i -(n-1)C\\
nAns = \sum_{i=1}^n E_i'-(n-1)C\\
p_1= \frac{S-x}{S}\times\frac{n-2}{n-1} \\
p_2=\frac{S-x}{S}\times \frac{1}{n-1} \\
p_3 =\frac{x}{S} \\
F_0 = \frac{1}{p_2}\\
F_x = (\frac{p_2+p_3}{p_2}-1)(\frac{1}{1-p_1}+F_{x-1})+ \frac{1}{1-p_1}\\
E'_i=\sum_{k=A_i}^{S-1}F_{k}, \ C= \sum_{k=0}^{S-1} F_k \\
\]
/*program by mangoyang*/
#pragma GCC optimize("Ofast", "inline")
#include<bits/stdc++.h>
#define inf ((ll) 3e18)
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
typedef long long ll;
using namespace std;
/*
 * Reads the next decimal integer from stdin into x.
 *
 * Every character before the first digit is skipped; if any skipped
 * character is '-', the parsed value is negated. The character that ends the
 * digit run is consumed as well.
 *
 * If the input ends before any digit is seen, x is left as 0 and the function
 * returns instead of polling getchar() forever.
 */
template <class T>
inline void read(T &x){
	x = 0;
	bool negative = false;
	int ch = getchar();
	// Skip to the first digit, remembering a minus sign; stop at end of input.
	while(ch != EOF && !isdigit(ch)){
		if(ch == '-') negative = true;
		ch = getchar();
	}
	// Accumulate the digit run (isdigit(EOF) is false, so EOF ends it too).
	for(; ch != EOF && isdigit(ch); ch = getchar()) x = x * 10 + (ch - '0');
	if(negative) x = -x;
}
// Capacity of the global arrays below.
const int N = 500005;
// Modulus used for all modular arithmetic in this file.
const int mod = 998244353;
// f[x]: per-state expected step counts filled in by main (F_x in the derivation).
int f[N];
// sum[x]: suffix sums of f, i.e. f[x] + f[x+1] + ... (mod `mod`).
int sum[N];
// a[1..n]: the input values.
int a[N];
// n: number of input values.
int n;
// S: total of a[1..n].
int S;
/*
 * In-place modular accumulate: adds y to x and subtracts `mod` once if the
 * sum reaches `mod`.
 * Assumes x + y fits in an int (true when both are at most `mod`).
 */
inline void up(int &x, int y){
	x += y;
	if(x >= mod) x -= mod;
}
/*
 * Square-and-multiply exponentiation: returns a raised to the power b,
 * reduced modulo `mod`. Returns 1 when b is 0.
 * The callers in this file pass b = mod - 2 to obtain a modular inverse.
 */
inline int Pow(int a, int b){
	long long result = 1;
	long long base = a;
	while(b){
		// Fold in the current power of the base when the low bit of b is set.
		if(b & 1) result = result * base % mod;
		base = base * base % mod;
		b >>= 1;
	}
	return static_cast<int>(result);
}
/*
 * Reads n and a[1..n] from stdin and prints the expected game length modulo
 * `mod`.
 *
 * With S = a[1] + ... + a[n] and x the number of biscuits the target holds:
 *   p1 = (S-x)/S * (n-2)/(n-1)
 *   p2 = (S-x)/S * 1/(n-1)
 *   p3 = x/S
 *   F_0 = 1/p2 = n-1
 *   F_x = ((p2+p3)/p2 - 1) * (1/(1-p1) + F_{x-1}) + 1/(1-p1)
 * f[x] stores F_x, the expected number of steps for the target to go from x
 * to x+1 biscuits. With E'_i = F_{a_i} + ... + F_{S-1} and
 * C = F_0 + ... + F_{S-1}, the answer is (sum_i E'_i - (n-1)*C) / n.
 * All divisions are multiplications by modular inverses (Pow(v, mod - 2)).
 */
int main(){
	read(n);
	for(int i = 1; i <= n; i++) read(a[i]), S += a[i];

	// At x = 0: p3 = 0 and p2 = 1/(n-1), so F_0 = n-1.
	f[0] = n - 1;
	const int InvS = Pow(S, mod - 2);
	const int C2 = Pow(n - 1, mod - 2);       // 1/(n-1)
	const int C1 = 1ll * (n - 2) * C2 % mod;  // (n-2)/(n-1), reusing the inverse above
	for(int i = 1; i < S; i++){
		const int rest = 1ll * (S - i) * InvS % mod;  // (S-i)/S
		const int p1 = 1ll * rest * C1 % mod;
		const int p2 = 1ll * rest * C2 % mod;
		const int p3 = 1ll * i * InvS % mod;
		// 1/(1-p1) appears twice in the recurrence; compute the inverse once.
		const int inv = Pow(mod + 1 - p1, mod - 2);
		f[i] = inv;
		up(f[i], f[i-1]);                          // 1/(1-p1) + F_{i-1}
		int tmp = 1ll * (p2 + p3) % mod * Pow(p2, mod - 2) % mod;
		up(tmp, mod - 1);                          // (p2+p3)/p2 - 1
		f[i] = 1ll * tmp * f[i] % mod;
		up(f[i], inv);
	}

	// sum[x] = f[x] + ... + f[S-1]; sum[S] stays 0 (zero-initialized global).
	for(int i = S - 1; i >= 0; i--)
		sum[i] = (sum[i+1] + f[i]) % mod;

	int res = 0;
	for(int i = 1; i <= n; i++) up(res, sum[a[i]]);        // sum of E'_i
	up(res, mod - 1ll * (n - 1) * sum[0] % mod);           // minus (n-1)*C
	res = 1ll * res * Pow(n, mod - 2) % mod;               // divide by n
	cout << res << '\n';
	return 0;
}