[BJOI2018]治疗之雨
XVI.[BJOI2018]治疗之雨
一眼能看出这是道高斯消元题。
我们设\(f_i\)表示当前英雄血量为\(i\)时期望多少次死掉。
则我们有
\[f_i=\dfrac{1}{m+1}\times\Big(\sum\limits_{j=0}^iq_jf_{i+1-j}\Big)+\dfrac{m}{m+1}\times\Big(\sum\limits_{j=0}^{i-1}q_jf_{i-j}\Big)+1
\]
其中\(q_j\)是打掉\(j\)点血的概率,有
\[q_j=\Big(\dfrac{1}{m+1}\Big)^j\Big(\dfrac{m}{m+1}\Big)^{k-j}\dbinom{k}{j}
\]
特殊地,当满血(即\(i=n\))时,因为必定奶不到英雄,故直接有
\[f_i=\sum\limits_{j=0}^{i-1}q_jf_{i-j}+1
\]
明显这个DP有后效性,必须得高斯消元;但是\(O(n^3)\)你告诉我能跑\(1500\)?
抱歉,还真能。
我们列出转移矩阵,发现它大概长这样:
\[\begin{bmatrix}&?,?,0,0,0,\dots,0\\&?,?,?,0,0,\dots,0\\&?,?,?,?,0,\dots,0\\&\dots\\&?,?,?,?,?,\dots,?\end{bmatrix}
\]
发现它十分接近倒着的上三角矩阵;故我们只需要消掉对角线上面的一行,就能把它消成一个真正的倒置上三角矩阵,然后就能\(O(n^2)\)代回解出了。
明显只消掉对角线上面一行的复杂度也是\(O(n^2)\)的;故总复杂度即为\(O(n^2)\)。
代码:
#include<bits/stdc++.h>
using namespace std;
const int mod=1e9+7;
int T,n,m,S,p,g[2010][2010],q[2010],inv[2010],f[2010];
int ksm(int x,int y){
int z=1;
for(;y;y>>=1,x=1ll*x*x%mod)if(y&1)z=1ll*x*z%mod;
return z;
}
void Gauss(){
for(int i=n;i>1;i--){
if(!g[i][i])swap(g[i],g[i-1]);
int tmp=1ll*g[i-1][i]*ksm(g[i][i],mod-2)%mod;
for(int j=1;j<=i;j++)(g[i-1][j]+=mod-1ll*g[i][j]*tmp%mod)%=mod;
(g[i-1][n+1]+=mod-1ll*g[i][n+1]*tmp%mod)%=mod;
}
// for(int i=1;i<=n;i++){for(int j=1;j<=n+1;j++)printf("%d ",g[i][j]);puts("");}
f[1]=1ll*g[1][n+1]*ksm(g[1][1],mod-2)%mod;
for(int i=2;i<=n;i++){
f[i]=g[i][n+1];
for(int j=1;j<i;j++)(f[i]+=mod-1ll*g[i][j]*f[j]%mod)%=mod;
f[i]=1ll*f[i]*ksm(g[i][i],mod-2)%mod;
}
}
int main(){
scanf("%d",&T);
for(int i=1;i<=2000;i++)inv[i]=ksm(i,mod-2);
while(T--){
scanf("%d%d%d%d",&n,&S,&m,&p);
if(p==0||p==1&&m==0&&(S!=n||n!=1)){puts("-1");continue;}
int P=ksm(m+1,mod-2),Q=1ll*m*P%mod;
for(int i=0;i<=n;i++)q[i]=0;
for(int i=0,j=1;i<=min(n,p);j=1ll*j*(p-i)%mod,i++,j=1ll*j*inv[i]%mod)q[i]=1ll*ksm(P,i)*ksm(Q,p-i)%mod*j%mod;
for(int i=1;i<=n;i++)for(int j=1;j<=n+1;j++)g[i][j]=0;
for(int i=1;i<=n;i++){
for(int j=0;j<i;j++)g[i][i-j]=1ll*q[j]*(i<n?mod+1-P:1)%mod;
if(i<n)for(int j=0;j<i+1;j++)(g[i][i+1-j]+=1ll*q[j]*P%mod)%=mod;
(g[i][i]+=mod-1)%=mod;
g[i][n+1]=mod-1;
}
// for(int i=1;i<=n;i++){for(int j=1;j<=n+1;j++)printf("%d ",g[i][j]);puts("");}
Gauss();
printf("%d\n",f[S]);
}
return 0;
}