luogu P5326 [ZJOI2019]开关
直接优化高斯消元似乎不是很可做,可以换一种思路。
我们先来考虑一个弱化版的问题:求某个时刻为当前局面的答案。
这个东西长得一脸指数生成函数,不妨列出来,设\(F_i\)为\(i\)号灯的生成函数,其中\(F_i[x_j]\)表示\(i\)号灯被操作\(j\)次的方案。
先来考虑\(s_i=0\),不妨设\(P=\sum p_i\)则
\[F_i=\sum\limits_{i=0}^{\infty}{\frac{(\frac{p_i}{P}x)^{2i}}{(2i)!}}=\frac{e^{\frac{p_i}{P}x}+e^{-\frac{p_i}{P}x}}{2}
\]
\(s_i=1\)的情况类似,可以整合为
\[F_i=\frac{e^{\frac{p_i}{P}x}+(-1)^{s_i}e^{-\frac{p_i}{P}x}}{2}
\]
那么这个弱化版答案的生成函数就是所有\(F_i\)乘起来。
再考虑这个问题,即第一次得到答案的时刻。我们发现如果再求出生成函数\(H\)表示某个时刻为全\(0\)的概率,记答案的生成函数为\(G\),则有\(GH=F\),具体原因就是考虑从第一次不断重复\(0\)的过程。因此我们求的就是\(\frac{F}{H}\)。
再考虑答案,容易发现就是\((\frac{F}{H})'(1)=\frac{F'(1)H(1)-F(1)H'(1)}{H(1)^2}\)。
先暴力\(O(nP)\)背包求出\(F\)每一项的系数设为\(a\),然后你发现这个无穷的还带\(e\)似乎不好求某个点的值。
先考虑去掉\(e\),将其化成普通生成函数:
\[F=\sum\limits_{i=-P}^P{a_ie^{\frac{i}{P}x}}\\
=\sum\limits_{i=-P}^P{a_i\sum\limits_{j=0}^{\infty}{\frac{i^j}{P^jj!}x^j}}\\
f=\sum\limits_{i=-P}^P{a_i\sum\limits_{j=0}^{\infty}{(\frac{i}{P})^jx^j}}\\
=\sum\limits_{i=-P}^P{\frac{a_i}{1-\frac{i}{P}x}}\\
\]
现在项数是有限了,但是在\(x=1\)处的取值还是没有……
因为这个是除法,所以可以上下同时乘上\((1-x)\),有
\[(1-x)f=\sum\limits_{i=-P}^P{\frac{a_i(1-x)}{1-\frac{i}{P}x}}\\
=a_P+\sum\limits_{i=-P}^{P-1}{\frac{a_i(1-x)}{1-\frac{i}{P}x}}\\
(1-x)f(1)=a_P\\
((1-x)f)'(1)=\sum\limits_{i=-P}^{P-1}\frac{a_i}{\frac{i}{P}-1}
\]
然后再对\(H\)做同样的事情就可以算答案了。
#include<bits/stdc++.h>
#define Gc() getchar()
#define Me(x,y) memset(x,y,sizeof(x))
#define Mc(x,y) memcpy(x,y,sizeof(x))
#define d(x,y) ((m)*(x-1)+(y))
#define R(n) (rnd()%(n)+1)
#define Pc(x) putchar(x)
#define LB lower_bound
#define UB upper_bound
#define PB push_back
using ll=long long;using db=double;using lb=long db;using ui=unsigned;using ull=unsigned ll;
using namespace std;const int N=1e2+5,M=1e5+5,K=2e3+5,mod=998244353,Mod=mod-1;const db eps=1e-9;const int INF=1e9+7;mt19937 rnd(time(0));
int n,m,k,x,y,z,A[N],B[N],P,C[N];ll f[M],g[M];
ll mpow(ll x,int y=mod-2){ll Ans=1;while(y) y&1&&(Ans=Ans*x%mod),y>>=1,x=x*x%mod;return Ans;}
#define fi first
#define se second
pair<ll,ll> calc(int *A){
int i,j;Me(f,0);f[P]=1;for(i=1;i<=n;i++){
Mc(g,f);Me(f,0);for(j=B[i];j<=2*P;j++) f[j]=g[j-B[i]];
for(j=0;j<=2*P-B[i];j++) f[j]=(f[j]+(A[i]?mod-g[j+B[i]]:g[j+B[i]]))%mod;
}
pair<ll,ll> Ans;Ans.fi=f[2*P];for(i=0;i<2*P;i++) Ans.se+=f[i]*P%mod*mpow(mod-2*P+i)%mod;Ans.se%=mod;return Ans;
}
int main(){
freopen("switch.in","r",stdin);freopen("switch.out","w",stdout);
int i,j;scanf("%d",&n);for(i=1;i<=n;i++) scanf("%d",&A[i]);for(i=1;i<=n;i++) scanf("%d",&B[i]),P+=B[i];
auto p1=calc(A),p2=calc(C);
printf("%lld\n",(p1.se*p2.fi+(mod-p1.fi)*p2.se)%mod*mpow(p2.fi*p2.fi%mod));
}