题解 赌神
很好的题
考场上写了巨久可惜没有写出正解
第一思路(其实假了):
发现肯定会把大于2的都消成2(其实假了)
然后对于这个问题,令 \(dp[i][j]\) 为为2的有 \(i\) 个,为1的有 \(j\) 个,进入这个状态时的筹码数变为最优策略下的最大筹码数要乘的系数
那我们令 \(t_1=dp[i-1][j],\ t_2=dp[i][j-1]\)
令进入此状态时筹码数为 \(x\),我们此时给为2的分配 \(k_1x\),为1的分配 \(k_2x\)
对手两种可能的对应收益为 \(y_1=k_1nt_1,\ y_2=k_2nt_2\)
最终收益为 \(min(y1, y2)\)
又有 \(k_1+k_2=1\),可以解出最大值
发现假了之后考虑 \(n^2\) 的部分分:
同理解得 \(dp[i][j] = \frac{2t_1t_2}{t_1+t_2}\)
可以直接DP
但到这里就卡住了
然后题解思路:
前面的柿子是一样的
发现这里很像在二维平面上只能向左、下走的那个情况,而且那个可以直接组合数
但这里有个2的系数,考场上不会处理
神仙思路:
令 \(f[i][j] = \frac{dp[i][j]}{2^{i+j}}\)
然后 \(f[i][j] = \frac{f[i-1][j]f[i][j-1]}{f[i-1][j]+f[i][j-1]}\)
而且 \(f[0][0]=1\),就成了想要的情况了
于是可以直接算,\(ans = \frac{2^{i+j}}{\binom{i+j}{i}}\)
- 当发现一个DP的转移很像在二维平面上只能向左、下走的那个情况,想尝试直接用组合数算
但有些额外的系数不知如何处理时,尝试用当前DP数组构造出一个满足直接转移的辅助数组
比如赌神这题
然后扩展到高维:
- n维平面内从 \((0, 0...0)\) 到 \((x_1, x_2...x_n)\) 的走法数为 \(\frac{(\sum\limits_{i=1}^n x_i)!}{\prod\limits_{i=1}^n x_i!}\)
于是就出来了
Code:
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 1000100
#define ll long long
#define ld long double
#define reg register int
#define fir first
#define sec second
#define make make_pair
//#define int long long
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
int ans=0, f=1; char c=getchar();
while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
return ans*f;
}
int n;
int x[N];
const ll mod=998244353;
ll qpow(ll a, ll b) {ll ans=1; for (; b; a=a*a%mod, b>>=1) if (b&1) ans=ans*a%mod; return ans;}
ld qpow2(ld a, ll b) {ld ans=1; for (; b; a*=a, b>>=1) if (b&1) ans=ans*a; return ans;}
inline ll inv(ll a) {return qpow(a%mod, mod-2);}
namespace task1{
//unordered_map<pair<int, int>, ll> mp;
pair<ld, ll> dfs(ll a, ll b) {
//cout<<"dfs "<<a<<' '<<b<<endl;
if (a==0 && b==1) return make(n, n);
if (a==1 && b==0) return make(1.0*n*n, n*n%mod);
if (a==0) {
pair<ld, ll> t=dfs(a, b-1);
return make((t.fir*n)/b, t.sec*n%mod*inv(b)%mod);
}
if (b==0) {
return dfs(a-1, b+1);
}
//cout<<"pos1: "<<a<<' '<<b<<endl;
pair<ld, ll> t1, t2, k1, k2; ld tem1, tem2;
t1=dfs(a-1, b+1); t2=dfs(a, b-1);
//cout<<"t: "<<t1.fir<<' '<<t2.fir<<endl;
//cout<<"t: "<<t1.sec<<' '<<t2.sec<<endl;
k2.fir=t1.fir/(1.0*a*t2.fir+1.0*b*t1.fir);
k2.sec=t1.sec*inv(a*t2.sec%mod+b*t1.sec%mod);
k1.fir=(1.0-1.0*b*k2.fir)/(1.0*a);
k1.sec=((1ll-b*k2.sec%mod)%mod+mod)%mod*inv(a)%mod;
//cout<<"k: "<<k1.fir<<' '<<k2.fir<<endl;
tem1=k1.fir*n*t1.fir; tem2=k2.fir*n*t2.fir;
//cout<<"return: "<<a<<' '<<b<<' '<<min(tem1, tem2)<<endl;
if (tem1<tem2) return make(tem1, k1.sec*n%mod*t1.sec%mod);
else return make(tem2, k2.sec*n%mod*t2.sec%mod);
}
void solve() {
int a=0, b=0;
for (int i=1; i<=n; ++i)
if (x[i]==1) ++b;
else if (x[i]==2) ++a;
printf("%lld\n", dfs(a, b).sec);
//cout<<"ans: "<<dfs(a, b).fir<<endl;
//cout<<dfs(1, 1).sec<<endl;
exit(0);
}
}
namespace task2{
pair<ld, ll> mp[1050][1050];
bool vis[1050][1050];
pair<ld, ll> dfs(ll a, ll b) {
//cout<<"dfs "<<a<<' '<<b<<endl;
if (a==0 && b==1) return make(n, n);
if (a==1 && b==0) return make(1.0*n*n, n*n%mod);
if (vis[a][b]) return mp[a][b];
if (a==0) {
pair<ld, ll> t=dfs(a, b-1);
vis[a][b]=1;
mp[a][b]=make((t.fir*n)/b, t.sec*n%mod*inv(b)%mod);
return mp[a][b];
}
if (b==0) {
vis[a][b]=1;
mp[a][b]=dfs(a-1, b+1);
return mp[a][b];
}
//cout<<"pos1: "<<a<<' '<<b<<endl;
pair<ld, ll> t1, t2, k1, k2; ld tem1, tem2;
t1=dfs(a-1, b+1); t2=dfs(a, b-1);
//cout<<"t: "<<t1.fir<<' '<<t2.fir<<endl;
//cout<<"t: "<<t1.sec<<' '<<t2.sec<<endl;
k2.fir=t1.fir/(1.0*a*t2.fir+1.0*b*t1.fir);
k2.sec=t1.sec*inv(a*t2.sec%mod+b*t1.sec%mod)%mod;
k1.fir=(1.0-1.0*b*k2.fir)/(1.0*a);
k1.sec=((1ll-b*k2.sec%mod)%mod+mod)%mod*inv(a)%mod;
//cout<<"k: "<<k1.fir<<' '<<k2.fir<<endl;
tem1=k1.fir*n*t1.fir; tem2=k2.fir*n*t2.fir;
//cout<<"return: "<<a<<' '<<b<<' '<<min(tem1, tem2)<<endl;
vis[a][b]=1;
if (tem1<tem2) mp[a][b]=make(tem1, k1.sec*n%mod*t1.sec%mod);
else mp[a][b]=make(tem2, k2.sec*n%mod*t2.sec%mod);
return mp[a][b];
}
void solve() {
int a=0, b=0;
for (int i=1; i<=n; ++i)
if (x[i]==1) ++b;
else if (x[i]==2) ++a;
printf("%lld\n", dfs(a, b).sec);
//cout<<"ans: "<<dfs(a, b).fir<<endl;
//cout<<dfs(1, 1).sec<<endl;
exit(0);
}
}
namespace task3{
ll dp[1050][1050];
void solve() {
dp[1][0]=dp[0][1]=n;
for (int i=2; i<=max(x[1], x[2]); ++i) {
dp[i][0]=dp[i-1][0]*n%mod;
dp[0][i]=dp[0][i-1]*n%mod;
}
for (int i=1; i<=x[1]; ++i) {
for (int j=1; j<=x[2]; ++j) {
ll t1=dp[i-1][j], t2=dp[i][j-1];
dp[i][j] = 1ll*n*t1%mod*t2%mod * inv(t1+t2)%mod;
}
}
printf("%lld\n", dp[x[1]][x[2]]);
exit(0);
}
}
namespace task{
ll fac[N], sum, prod=1;
void solve() {
fac[0]=fac[1]=1;
for (int i=2; i<N; ++i) fac[i]=fac[i-1]*i%mod;
for (int i=1; i<=n; ++i) sum+=x[i], prod=prod*fac[x[i]]%mod;
printf("%lld\n", qpow(n, sum)*prod%mod*inv(fac[sum])%mod);
exit(0);
}
}
signed main()
{
n=read();
for (int i=1; i<=n; ++i) x[i]=read();
if (n==1) {printf("%lld\n", qpow(n, x[1])); return 0;}
task::solve();
return 0;
}