洛谷 P5326 [ZJOI2019]开关
洛谷 P5326 [ZJOI2019]开关
https://www.luogu.com.cn/problem/P5326
Tutorial
https://www.luogu.com.cn/blog/xht37/solution-p5326
https://www.cnblogs.com/PinkRabbit/p/ZJOI2019D2T1.html
令\(p_i=\dfrac {p_i}{\sum p_i}\)
设\(f(x)\)表示在第\(k\)步到达合法状态的概率的生成函数,因为只关心第一次到达合法状态的情况,所以设\(g(x)\)表示走\(k\)步后回到原来的状态的概率,\(h(x)\)表示第\(k\)步第一次走到合法状态的概率,则有\(f(x)=g(x)h(x) \to h(x)=\dfrac{f(x)}{g(x)}\) .设\(h(x)=\sum a_k x^k\),则我们要求就是
\[\sum ka_k=h'(1)=\dfrac{f'(1)g(1)-f(1)g'(1)}{g^2(1)}
\]
考虑如何求\(f(x)\).到达合法状态的条件为选择开关\(i\)的次数与\(s_i\)相等.则有
\[F_i(x)=\dfrac{e^{p_ix}+(-1)^{s_i}e^{-p_ix}}2
\]
发现\(f(x)\)是OGF,\(F_i(x)\)为EGF,为了相互转化,将\(\prod F_i(x)\)表示为\(\sum c_k(e^x)^k\)的形式,其中\(c_k\)可以用背包在\(O(n\sum p)\)的时间求得,最后得到
\[\begin{align}
f(x)&=\sum_k ([x^k]k!\sum_i c_i(e^x)^i)x^k \\
&=\sum_k(k!\sum_i c_i [x^k](e^x)^i)x^k \\
&=\sum_k(k!\sum_ic_i\dfrac{i^k}{k!})x^k \\
&=\sum_k (\sum_i c_ii^k)x^k \\
&=\sum_ic_i\sum_{k}i^kx^k \\
&=\sum_i\dfrac{c_i}{1-ix}
\end{align}
\]
\(g(x)\)的处理类似,最后得到
\[g(x)=\sum_i\dfrac{d_i}{1-ix}
\]
但是发现当\(i=1\)时会有\(1-x\)这一项,所以不能直接将\(x=1\)带入,考虑分子分母同乘\((1-x)\),得到新的\(f(x),g(x)\)
\[f(x)=\sum_i\dfrac{c_i(1-x)}{1-ix}=c_1+\sum_{i\not=1}\dfrac{c_i(1-x)}{1-ix}
\]
所以此时\(f(1)=c_1\)
\[f'(x)=\sum_{i\not=1}\dfrac{c_i(ix-1)+ic_i(1-x)}{(1-ix)^2} \\
f'(1)=\sum_{i\not=1}\dfrac{c_i(i-1)}{(1-i)^2}=\sum_{i\not=1}\dfrac{c_i}{i-1}
\]
\(g(1),g'(1)\)也类似计算,即可得到答案.
Code
#include <cstdio>
#include <cstring>
#include <iostream>
#define debug(...) fprintf(stderr,__VA_ARGS__)
#define inver(a) power(a,mod-2)
using namespace std;
inline char gc() {
// return getchar();
static char buf[100000],*l=buf,*r=buf;
return l==r&&(r=(l=buf)+fread(buf,1,100000,stdin),l==r)?EOF:*l++;
}
template<class T> void rd(T &x) {
x=0; int f=1,ch=gc();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=gc();}
while(ch>='0'&&ch<='9'){x=x*10-'0'+ch;ch=gc();}
x*=f;
}
typedef long long ll;
const int mod=998244353,r2=(mod+1)>>1;
const int maxn=100+5,maxP=1e5+50;
int n,P,s[maxn],p[maxn];
int c[2][maxP],d[2][maxP];
inline int sub(int x) {return x<0?x+mod:x;}
ll power(ll x,ll y) {
ll re=1;
while(y) {
if(y&1) re=re*x%mod;
x=x*x%mod;
y>>=1;
}
return re;
}
inline int sqr(int x) {return (ll)x*x%mod;}
inline void upd(int *a,int *b,int v,int w) {
for(int i=0;i<=(P<<1);++i) if(b[i]) {
a[i+w]=(a[i+w]+(ll)v*b[i])%mod;
}
}
int main() {
rd(n);
for(int i=1;i<=n;++i) rd(s[i]);
for(int i=1;i<=n;++i) rd(p[i]),P+=p[i];
int cur=0;
c[cur][P]=d[cur][P]=1;
for(int i=1;i<=n;++i) {
cur^=1;
memset(c[cur],0,sizeof(c[cur])),memset(d[cur],0,sizeof(d[cur]));
upd(c[cur],c[cur^1],r2,p[i]),upd(c[cur],c[cur^1],(ll)r2*(s[i]==1?mod-1:1)%mod,-p[i]);
upd(d[cur],d[cur^1],r2,p[i]),upd(d[cur],d[cur^1],r2,-p[i]);
}
int an=0,c1=c[cur][P<<1],d1=d[cur][P<<1],t=inver(P);
for(int i=-P;i<P;++i) {
an=(an+inver(sub((ll)i*t%mod-1))*sub((ll)c[cur][i+P]*d1%mod-(ll)c1*d[cur][i+P]%mod))%mod;
}
an=(ll)an*sqr(inver(d1))%mod;
printf("%d\n",an);
return 0;
}