[BJOI2018]治疗之雨
我还没疯
发现如果我们将血量抽象成点,一轮操作抽象成图上的一条边,我们如果能求出每一条边的概率,我们就能搞一下这道题
假设我们求出了这个图\(E\),设\(dp_i\)表示从\(i\)点到达\(0\)点的期望路径长度
那么我们可以列出如下的方程
\[dp_u=\sum_{(u,v)\in E}P(u,v)\times(dp_v+1)
\]
发现这个方程可以高斯消元来做
问题变成了如何求出这张图
我们如求出了经过\(k\)次减小的操作,血量\(i\)变成血量\(j\)的概率是多少,我们讨论一下那个增加的操作,就能把这张图求出来了
发现\(k\)很大\(n\)较小,于是觉得可以矩阵优化
优化个鬼啊
设\(dp_{i,j}\)表示经过\(i\)次减小操作血量为\(j\)的概率
显然有
\[dp_{i,j}=\frac{1}{m+1}\times dp_{i-1,j+1}+\frac{m}{m+1}\times dp_{i-1,j}
\]
发现这个柿子确实可以矩阵优化一下,变成\(O(n^3logk)\)
之后套上消元,就有\(40\)分了
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#define maxn 205
#define re register
#define LL long long
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
const LL mod=1000000007;
inline int read() {
int x=0;char c=getchar();while(c<'0'||c>'9') c=getchar();
while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
void exgcd(LL a,LL b,LL &x,LL &y) {if(!b) {x=1,y=0;return;}exgcd(b,a%b,y,x);y-=a/b*x;}
inline LL inv(LL t) {LL x,y;exgcd(t,mod,x,y);return (x%mod+mod)%mod;}
int n,p,T;
LL m,k;
LL ans[maxn][maxn],a[maxn][maxn],b[maxn][maxn],E[maxn];
inline void did_ans() {
LL mid[maxn][maxn];
for(re int i=0;i<=n;i++)
for(re int j=0;j<=n;j++) mid[i][j]=ans[i][j],ans[i][j]=0;
for(re int k=0;k<=n;k++)
for(re int i=0;i<=n;i++)
for(re int j=0;j<=n;j++)
ans[i][j]=(ans[i][j]+a[i][k]*mid[k][j]%mod)%mod;
}
inline void did_a() {
LL mid[maxn][maxn];
for(re int i=0;i<=n;i++)
for(re int j=0;j<=n;j++) mid[i][j]=a[i][j],a[i][j]=0;
for(re int k=0;k<=n;k++)
for(re int i=0;i<=n;i++)
for(re int j=0;j<=n;j++)
a[i][j]=(a[i][j]+mid[i][k]*mid[k][j]%mod)%mod;
}
inline void Quick(LL b) {while(b) {if(b&1) did_ans();b>>=1ll;did_a();}}
int main() {
T=read();
while(T--) {
n=read(),p=read(),m=read(),k=read();
memset(b,0,sizeof(b));memset(ans,0,sizeof(ans));
memset(E,0,sizeof(E));memset(a,0,sizeof(a));
if(!p) {puts("0");continue;}
if(!k) {puts("-1");continue;}
if(!m&&k==1) {puts("-1");continue;}
if(!m){
int ans=0;
while(p>0){if(p<n) p++;p-=k;ans++;}
printf("%d\n",ans); continue;
}
LL Inv=inv(m+1ll);
for(re int i=1;i<=n;i++)
a[i][i+1]=Inv,a[i][i]=m*Inv%mod;a[n][n+1]=0;
a[0][0]=1,a[0][1]=Inv;
for(re int i=0;i<=n;i++)
ans[i][i]=1;
Quick(k);
for(re int i=1;i<n;i++) {
for(re int j=1;j<=i;j++)
b[i][j]+=ans[j][i]*Inv%mod*m%mod,b[j][i]%=mod;
for(re int j=1;j<=i+1;j++)
b[i][j]+=ans[j][i+1]*Inv%mod,b[j][i]%=mod;
}
for(re int i=1;i<=n;i++) b[n][i]=ans[i][n];
for(re int i=1;i<=n;i++) b[i][n+1]=mod-1;
for(re int i=1;i<=n;i++) b[i][i]-=1ll,b[i][i]=(b[i][i]+mod)%mod;
for(re int i=1;i<=n;i++) {
LL t=inv(b[i][i]);
for(re int j=n+1;j>=i;--j)
b[i][j]=(b[i][j]*t)%mod;
for(re int j=i+1;j<=n;j++)
for(re int k=n+1;k>=i;--k)
b[j][k]=(b[j][k]-b[j][i]*b[i][k]%mod+mod)%mod;
}
E[n]=b[n][n+1];
for(re int i=n-1;i;--i) {
E[i]=b[i][n+1];
for(re int j=i+1;j<=n;j++)
E[i]=(E[i]-E[j]*b[i][j]%mod+mod)%mod;
}
printf("%lld\n",E[p]);
}
return 0;
}
显然复杂度不对,尤其是矩阵这边
根据高中数学必修三,显然血量减少\(r\)的概率应该是
\[\frac{\binom{k}{r}m^{k-r}}{(m+1)^k}
\]
于是处理一下组合数就不用矩阵了
现在的复杂度就只剩下高斯消元的\(O(n^3)\)了
发现这个矩阵非常特殊,第\(i\)行只有前\(i+1\)列有值,于是我们往下消元的时候一行只需要消三个数就可以了
复杂度\(O(Tn^2)\)
代码
#include<iostream>
#include<cstring>
#include<cstdio>
#define maxn 1505
#define re register
#define LL long long
const LL mod=1000000007;
inline int read() {
int x=0;char c=getchar();while(c<'0'||c>'9') c=getchar();
while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
int n,p,T;LL m,k;
LL ans[maxn],b[maxn][maxn],E[maxn],C[maxn];
void exgcd(LL a,LL b,LL &x,LL &y) {if(!b) {x=1,y=0;return;}exgcd(b,a%b,y,x);y-=a/b*x;}
inline LL inv(LL t) {LL x,y;exgcd(t,mod,x,y);return (x%mod+mod)%mod;}
inline LL quick(LL a,LL b) {LL S=1;while(b) {if(b&1ll) S=S*a%mod;b>>=1ll;a=a*a%mod;}return S;}
int main() {
T=read();
while(T--) {
n=read(),p=read(),m=read(),k=read();
memset(b,0,sizeof(b));memset(ans,0,sizeof(ans));
memset(E,0,sizeof(E));memset(C,0,sizeof(C));
if(!p) {puts("0");continue;}
if(!k) {puts("-1");continue;}
if(!m&&k==1) {puts("-1");continue;}
if(!m){
int ans=0;
while(p>0){if(p<n) p++;p-=k;ans++;}
printf("%d\n",ans); continue;
}
LL Inv=inv(m+1ll),D=inv(quick(m+1ll,k));
C[0]=1;LL now=k,fac=1,t=k;C[1]=k;
for(re int i=2;i<=n;i++) t--,fac=(fac*i)%mod,now=(now*t)%mod,C[i]=now*inv(fac)%mod;
for(re int i=0;i<=n&&k>=i;i++) ans[i]=C[i]*quick(m,k-i)%mod*D%mod;
for(re int i=1;i<n;i++) {
for(re int j=1;j<=i;j++)
b[i][j]+=ans[i-j]*Inv%mod*m%mod,b[j][i]%=mod;
for(re int j=1;j<=i+1;j++)
b[i][j]+=ans[i+1-j]*Inv%mod,b[j][i]%=mod;
}
for(re int i=1;i<=n;i++) b[n][i]=ans[n-i];
for(re int i=1;i<=n;i++) b[i][n+1]=mod-1;
for(re int i=1;i<=n;i++) b[i][i]-=1ll,b[i][i]=(b[i][i]+mod)%mod;
for(re int i=1;i<=n;i++) {
LL t=inv(b[i][i]);
for(re int j=n+1;j>=i;--j)
b[i][j]=(b[i][j]*t)%mod;
for(re int j=i+1;j<=n;j++) {
b[j][n+1]=(b[j][n+1]-b[j][i]*b[i][n+1]%mod+mod)%mod;
for(re int k=i+1;k>=i;k--)
b[j][k]=(b[j][k]-b[j][i]*b[i][k]%mod+mod)%mod;
}
}
E[n]=b[n][n+1];
for(re int i=n-1;i;--i) {
E[i]=b[i][n+1];
for(re int j=i+1;j<=n;j++)
E[i]=(E[i]-E[j]*b[i][j]%mod+mod)%mod;
}
printf("%lld\n",E[p]);
}
return 0;
}