题解 赌神

传送门

很好的题
考场上写了巨久可惜没有写出正解

第一思路(其实假了):
发现肯定会把大于2的都消成2(其实假了)
然后对于这个问题,令 \(dp[i][j]\) 为为2的有 \(i\) 个,为1的有 \(j\) 个,进入这个状态时的筹码数变为最优策略下的最大筹码数要乘的系数
那我们令 \(t_1=dp[i-1][j],\ t_2=dp[i][j-1]\)
令进入此状态时筹码数为 \(x\),我们此时给为2的分配 \(k_1x\),为1的分配 \(k_2x\)
对手两种可能的对应收益为 \(y_1=k_1nt_1,\ y_2=k_2nt_2\)
最终收益为 \(min(y1, y2)\)
又有 \(k_1+k_2=1\),可以解出最大值

发现假了之后考虑 \(n^2\) 的部分分:
同理解得 \(dp[i][j] = \frac{2t_1t_2}{t_1+t_2}\)
可以直接DP

但到这里就卡住了
然后题解思路:
前面的柿子是一样的
发现这里很像在二维平面上只能向左、下走的那个情况,而且那个可以直接组合数
但这里有个2的系数,考场上不会处理
神仙思路:
\(f[i][j] = \frac{dp[i][j]}{2^{i+j}}\)
然后 \(f[i][j] = \frac{f[i-1][j]f[i][j-1]}{f[i-1][j]+f[i][j-1]}\)
而且 \(f[0][0]=1\),就成了想要的情况了
于是可以直接算,\(ans = \frac{2^{i+j}}{\binom{i+j}{i}}\)

  • 当发现一个DP的转移很像在二维平面上只能向左、下走的那个情况,想尝试直接用组合数算
    但有些额外的系数不知如何处理时,尝试用当前DP数组构造出一个满足直接转移的辅助数组
    比如赌神这题

然后扩展到高维:

  • n维平面内从 \((0, 0...0)\)\((x_1, x_2...x_n)\) 的走法数为 \(\frac{(\sum\limits_{i=1}^n x_i)!}{\prod\limits_{i=1}^n x_i!}\)

于是就出来了

Code:
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 1000100
#define ll long long
#define ld long double
#define reg register int
#define fir first
#define sec second
#define make make_pair
//#define int long long 

char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
	int ans=0, f=1; char c=getchar();
	while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

int n;
int x[N];
const ll mod=998244353;
ll qpow(ll a, ll b) {ll ans=1; for (; b; a=a*a%mod, b>>=1) if (b&1) ans=ans*a%mod; return ans;}
ld qpow2(ld a, ll b) {ld ans=1; for (; b; a*=a, b>>=1) if (b&1) ans=ans*a; return ans;}
inline ll inv(ll a) {return qpow(a%mod, mod-2);}

namespace task1{
	//unordered_map<pair<int, int>, ll> mp;
	pair<ld, ll> dfs(ll a, ll b) {
		//cout<<"dfs "<<a<<' '<<b<<endl;
		if (a==0 && b==1) return make(n, n);
		if (a==1 && b==0) return make(1.0*n*n, n*n%mod);
		if (a==0) {
			pair<ld, ll> t=dfs(a, b-1);
			return make((t.fir*n)/b, t.sec*n%mod*inv(b)%mod);
		}
		if (b==0) {
			return dfs(a-1, b+1);
		}
		//cout<<"pos1: "<<a<<' '<<b<<endl;
		pair<ld, ll> t1, t2, k1, k2; ld tem1, tem2;
		t1=dfs(a-1, b+1); t2=dfs(a, b-1);
		//cout<<"t: "<<t1.fir<<' '<<t2.fir<<endl;
		//cout<<"t: "<<t1.sec<<' '<<t2.sec<<endl;
		k2.fir=t1.fir/(1.0*a*t2.fir+1.0*b*t1.fir);
		k2.sec=t1.sec*inv(a*t2.sec%mod+b*t1.sec%mod);
		k1.fir=(1.0-1.0*b*k2.fir)/(1.0*a);
		k1.sec=((1ll-b*k2.sec%mod)%mod+mod)%mod*inv(a)%mod;
		//cout<<"k: "<<k1.fir<<' '<<k2.fir<<endl;
		tem1=k1.fir*n*t1.fir; tem2=k2.fir*n*t2.fir;
		//cout<<"return: "<<a<<' '<<b<<' '<<min(tem1, tem2)<<endl;
		if (tem1<tem2) return make(tem1, k1.sec*n%mod*t1.sec%mod);
		else return make(tem2, k2.sec*n%mod*t2.sec%mod);
	}
	void solve() {
		int a=0, b=0;
		for (int i=1; i<=n; ++i)
			if (x[i]==1) ++b;
			else if (x[i]==2) ++a;
		printf("%lld\n", dfs(a, b).sec);
		//cout<<"ans: "<<dfs(a, b).fir<<endl;
		//cout<<dfs(1, 1).sec<<endl;
		exit(0);
	}
}

namespace task2{
	pair<ld, ll> mp[1050][1050];
	bool vis[1050][1050];
	pair<ld, ll> dfs(ll a, ll b) {
		//cout<<"dfs "<<a<<' '<<b<<endl;
		if (a==0 && b==1) return make(n, n);
		if (a==1 && b==0) return make(1.0*n*n, n*n%mod);
		if (vis[a][b]) return mp[a][b];
		if (a==0) {
			pair<ld, ll> t=dfs(a, b-1);
			vis[a][b]=1;
			mp[a][b]=make((t.fir*n)/b, t.sec*n%mod*inv(b)%mod);
			return mp[a][b];
		}
		if (b==0) {
			vis[a][b]=1;
			mp[a][b]=dfs(a-1, b+1);
			return mp[a][b];
		}
		//cout<<"pos1: "<<a<<' '<<b<<endl;
		pair<ld, ll> t1, t2, k1, k2; ld tem1, tem2;
		t1=dfs(a-1, b+1); t2=dfs(a, b-1);
		//cout<<"t: "<<t1.fir<<' '<<t2.fir<<endl;
		//cout<<"t: "<<t1.sec<<' '<<t2.sec<<endl;
		k2.fir=t1.fir/(1.0*a*t2.fir+1.0*b*t1.fir);
		k2.sec=t1.sec*inv(a*t2.sec%mod+b*t1.sec%mod)%mod;
		k1.fir=(1.0-1.0*b*k2.fir)/(1.0*a);
		k1.sec=((1ll-b*k2.sec%mod)%mod+mod)%mod*inv(a)%mod;
		//cout<<"k: "<<k1.fir<<' '<<k2.fir<<endl;
		tem1=k1.fir*n*t1.fir; tem2=k2.fir*n*t2.fir;
		//cout<<"return: "<<a<<' '<<b<<' '<<min(tem1, tem2)<<endl;
		vis[a][b]=1;
		if (tem1<tem2) mp[a][b]=make(tem1, k1.sec*n%mod*t1.sec%mod);
		else mp[a][b]=make(tem2, k2.sec*n%mod*t2.sec%mod);
		return mp[a][b];
	}
	void solve() {
		int a=0, b=0;
		for (int i=1; i<=n; ++i)
			if (x[i]==1) ++b;
			else if (x[i]==2) ++a;
		printf("%lld\n", dfs(a, b).sec);
		//cout<<"ans: "<<dfs(a, b).fir<<endl;
		//cout<<dfs(1, 1).sec<<endl;
		exit(0);
	}
}

namespace task3{
	ll dp[1050][1050];
	void solve() {
		dp[1][0]=dp[0][1]=n;
		for (int i=2; i<=max(x[1], x[2]); ++i) {
			dp[i][0]=dp[i-1][0]*n%mod;
			dp[0][i]=dp[0][i-1]*n%mod;
		}
		for (int i=1; i<=x[1]; ++i) {
			for (int j=1; j<=x[2]; ++j) {
				ll t1=dp[i-1][j], t2=dp[i][j-1];
				dp[i][j] = 1ll*n*t1%mod*t2%mod * inv(t1+t2)%mod;
			}
		}
		printf("%lld\n", dp[x[1]][x[2]]);
		exit(0);
	}
}

namespace task{
	ll fac[N], sum, prod=1;
	void solve() {
		fac[0]=fac[1]=1;
		for (int i=2; i<N; ++i) fac[i]=fac[i-1]*i%mod;
		for (int i=1; i<=n; ++i) sum+=x[i], prod=prod*fac[x[i]]%mod;
		printf("%lld\n", qpow(n, sum)*prod%mod*inv(fac[sum])%mod);
		exit(0);
	}
}

signed main()
{
	n=read();
	for (int i=1; i<=n; ++i) x[i]=read();
	if (n==1) {printf("%lld\n", qpow(n, x[1])); return 0;}
	task::solve();
	
	return 0;
}
posted @ 2021-09-13 18:59  Administrator-09  阅读(3)  评论(0编辑  收藏  举报