[BZOJ5292][BJOI2018]治疗之雨(概率DP+高斯消元)
https://blog.csdn.net/xyz32768/article/details/83217209
不难找到DP方程与辅助DP方程,发现DP方程具有后效性,于是高斯消元即可。
但朴素消元显然无法通过,注意到f[i]的方程至多与f[i+1]有关,于是从下往上依次消去最后一个数,剩下的就是一个下三角,直接求解即可。
注意中间与指数有关的计算能预处理的就不用快速幂,以及阶乘等值可以在程序开头预处理。
复杂度$O(n^2)$,不知道为什么和别人的代码相比常数巨大。
1 #include<cstdio> 2 #include<algorithm> 3 #define rep(i,l,r) for (int i=(l); i<=(r); i++) 4 using namespace std; 5 6 const int N=3010,mod=1e9+7; 7 int n,m,p,k,T,d[N],pw[N],fac[N],inv[N],P[N][N],a[N][N]; 8 9 int ksm(int a,int b){ 10 int res=1; 11 for (; b; a=1ll*a*a%mod,b>>=1) 12 if (b & 1) res=1ll*res*a%mod; 13 return res; 14 } 15 16 bool Gauss(){ 17 for (int i=n; i; i--){ 18 if (!a[i][i]) return 0; 19 int t=1ll*a[i-1][i]*ksm(a[i][i],mod-2)%mod; 20 rep(j,0,i) a[i-1][j]=(a[i-1][j]-1ll*t*a[i][j]%mod+mod)%mod; 21 a[i-1][n+1]=(a[i-1][n+1]-1ll*t*a[i][n+1]%mod+mod)%mod; 22 } 23 rep(i,1,n){ 24 rep(j,0,i-1) a[i][n+1]=(a[i][n+1]-1ll*a[i][j]*a[j][n+1]%mod+mod)%mod; 25 a[i][n+1]=1ll*a[i][n+1]*ksm(a[i][i],mod-2)%mod; 26 } 27 return 1; 28 } 29 30 int main(){ 31 freopen("heal.in","r",stdin); 32 freopen("heal.out","w",stdout); 33 n=1500; 34 fac[0]=1; rep(i,1,n) fac[i]=1ll*fac[i-1]*i%mod; 35 inv[n]=ksm(fac[n],mod-2); 36 for (int i=n-1; ~i; i--) inv[i]=1ll*inv[i+1]*(i+1)%mod; 37 for (scanf("%d",&T); T--; ){ 38 scanf("%d%d%d%d",&n,&p,&m,&k); 39 rep(i,0,n+1) rep(j,0,n+1) a[i][j]=P[i][j]=0; 40 d[0]=1; rep(i,1,min(n,k)) d[i]=1ll*d[i-1]*(k-i+1)%mod; 41 pw[0]=ksm(m,k); int t=ksm(m,mod-2); 42 rep(i,1,min(n,k)) pw[i]=1ll*pw[i-1]*t%mod; 43 if (k<=n) pw[k]=1; 44 if (!k || (!m && k==1)){ puts("-1"); continue; } 45 t=ksm(ksm(m+1,k),mod-2); 46 rep(i,1,n){ 47 int sm=0; 48 rep(j,0,min(i,k)) 49 P[i][j]=(i==j)?(1-sm+mod)%mod:1ll*d[j]*inv[j]%mod*pw[j]%mod*t%mod,sm+=P[i][j]; 50 } 51 a[0][0]=1; int inv=ksm(m+1,mod-2); 52 rep(i,1,n-1){ 53 a[i][n+1]=a[i][i]=mod-1; 54 rep(j,0,i+1){ 55 a[i][i-j+1]=(a[i][i-j+1]+1ll*P[i+1][j]*inv)%mod; 56 a[i][i-j]=(a[i][i-j]+1ll*P[i][j]*inv%mod*m)%mod; 57 } 58 } 59 a[n][n+1]=a[n][n]=mod-1; 60 rep(j,0,n) a[n][n-j]=(a[n][n-j]+P[n][j])%mod; 61 if (Gauss()) printf("%d\n",a[p][n+1]); else puts("-1"); 62 } 63 return 0; 64 }