Loj #3045. 「ZJOI2019」开关
Loj #3045. 「ZJOI2019」开关
题目描述
九条可怜是一个贪玩的女孩子。
这天,她和她的好朋友法海哥哥去玩密室逃脱。在他们面前的是 \(n\) 个开关,开始每个开关都是关闭的状态。要通过这关,必须要让开关达到指定的状态。目标状态由一个长度为 \(n\) 的 \(01\) 数组 \(s\) 给出,\(s_i = 0\) 表示第 \(i\) 个开关在最后需要是关着的,\(s_i = 1\) 表示第 \(i\) 个开关在最后需要被打开。
然而作为闯关者,可怜和法海并不知道 \(s\)。因此他们决定采用一个比较稳妥的方法:瞎按。他们根据开关的外形、位置,通过一些玄学的方法给每一个开关赋予了一个值 \(p_i(p_i > 0)\)。每一次,他们会以正比于 \(p_i\) 的概率(第 \(i\) 个开关被选中的概率是 \(\frac{p_i}{\sum_{j=1}^n p_j}\))选择并按下一个开关。开关被按下后,状态会被反转,即开变关,关变开。注意,每一轮的选择都是完全独立的。
在按开关的过程中,一旦当前开关的状态达到了 \(s\),那么可怜和法海面前的门就会打开,他们会马上停止按开关的过程并前往下一关。作为一名欧皇,可怜在按了\(\sum_{i=1}^n s_i\) 次后,就打开了大门。为了感受一下自己的运气是多么的好,可怜想要让你帮她计算一下,用这种随机按的方式,期望需要按多少次开关才能通过这一关。
输入格式
第一行输入一个整数 \(n\),表示开关的数量。
第二行输入 \(n\) 个整数 \(s_i(s_i \in \{0, 1\})\),表示开关的目标状态。第三行同样输入 \(n\) 个整数 \(p_i\),表示每个开关的权值。
输出格式
输出一行一个整数,表示答案对 \(998244353\) 取模后的值。即如果答案的最简分数表示为 \(\frac{x}{y}\ (x \ge 0, y \ge 1, \gcd(x, y) = 1)\),你需要输出 \(x \times y^{−1} \bmod 998244353\)。
数据范围与提示
对于 \(100\%\) 的测试数据,保证 \(n\ge 1,\sum_{i=1}^n p_i\le 5\times 10^4,p_i\ge 1\)。
\(\\\)
完全不会啊,还是\(\text{Orz}\)别人的题解好了。https://www.cnblogs.com/zhoushuyu/p/10687696.html
首先设\(P=\sum_{i=1}^np_i\)
为了方便,下面的\(p_i=\frac{p_i}{P}\)。
先求出走了\(n\)步到达结束状态的概率的\(EGF\):
考虑当\(s_i\)为\(1\)的时候,结束状态时应该按了\(i\)奇数次,否则按了偶数次。
我们也知道:
所以概率的\(EGF\)为:
但是有可能在走完\(n\)步之前就经过了终止状态。全过程可以认为是先走到了终止状态,再绕了若干圈回到终止状态。
考虑走了\(n\)步回到出发点的概率的\(EGF\):
设它们的\(OGF\)分别为\(h,f,g\),则:
下面考虑求\(OGF\)。
我们将\(F\)换一种写法:
则:
答案为\(h'(1)\)(带入一下就知道了):
但是带入\(x=1\)时\(f(x)\)和\(g(x)\)不收敛,于是考虑\(f(x)\)与\(g(x)\)同时乘上\(\prod_{i}(1-\frac{i}{P}x)\)。此时
容易得到:
接下来我们要求\(f'(1)\)。首先我们要利用下面的公式:
证明的话可以将左边部分暴力展开,求导,就会发现等于右边的部分。
对\(f(x)\)求导的话可以分为两部分:
1.首先是\(i\neq P\)的部分:
因为我最后要的是\(f'(1)\)啊,所以将\(x=1\)带入就会发现,当\(k\neq P\)的时候后面的\(\prod_{j\neq i,j\neq k}(1-\frac{j}{P}x)=0\)。所以:
2.然后是\(i=P\)的部分:
\(\\\)
于是
带入\(h'(1)=\frac{f'(1)g(1)-f(1)g'(1)}{g^2(1)}\)就可以了。
设\(f(x)=\sum a_i \prod_{j\neq i}(1-\frac{j}{P}x),g(x)=\sum b_i \prod_{j\neq i}(1-\frac{j}{P}x)\):
又因为\(a_P=b_P=\frac{1}{2^n}\),所以
下面考虑计算\(a_i,b_i\)。
以\(G(x)\)为例:
可以发现这就是个背包。
代码:
#include<bits/stdc++.h>
#define ll long long
#define N 50005
using namespace std;
inline int Get() {int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}
const ll mod=998244353;
ll ksm(ll t,ll x) {
ll ans=1;
for(;x;x>>=1,t=t*t%mod)
if(x&1) ans=ans*t%mod;
return ans;
}
int n;
ll sum,invs;
int s[N],p[N];
ll f[N<<1],g[N<<1];
ll tem[N<<1];
const ll inv2=mod+1>>1;
int main() {
n=Get();
for(int i=1;i<=n;i++) s[i]=Get();
for(int i=1;i<=n;i++) p[i]=Get();
f[N]=g[N]=1;
for(int i=1;i<=n;i++) {
sum+=p[i];
ll flag=s[i]?mod-inv2:inv2;
memset(tem,0,sizeof(tem));
for(int j=sum;j>=p[i]-sum;j--) tem[N+j]=f[N+j-p[i]]*inv2%mod;
for(int j=-sum;j<=sum-p[i];j++) (tem[N+j]+=f[N+j+p[i]]*flag)%=mod;
memcpy(f,tem,sizeof(f));
memset(tem,0,sizeof(tem));
for(int j=sum;j>=p[i]-sum;j--) tem[N+j]=g[N+j-p[i]]*inv2%mod;
for(int j=-sum;j<=sum-p[i];j++) (tem[N+j]+=g[N+j+p[i]]*inv2)%=mod;
memcpy(g,tem,sizeof(g));
}
ll ans=0;
for(int i=-sum;i<sum;i++) (ans+=sum*ksm(sum-i,mod-2)%mod*(mod-f[N+i]+g[N+i]))%=mod;
cout<<ans*ksm(2,n)%mod;
return 0;
}