题解 HDU6710 Kaguya

题意

题解

以下是官方题解:

不知道你看懂了没,反正我是没看懂。

这里介绍一波 Linshey 大佬的做法。

首先两个随机选点是忽悠你的,其实把两边都固定成 \(1\) 号点也一样,也就是相当于求左边 \(1\) 号点到右边 \(1\) 号点的距离期望。

然后就可以愉快地 3D1D DP 了!

\(f_{n,m,k}\) 表示左边还剩 \(n\) 个点,右边还剩 \(m\) 个点,上一步是右边选了 \(k\) 个点,此时到达右边 \(1\) 号点的期望距离。

\(g_{n,m,k}\) 表示左边还剩 \(n\) 个点,右边还剩 \(m\) 个点,上一步是左边选了 \(k\) 个点,此时到达右边 \(1\) 号点的期望距离。

\(F_{n,m,k}\) 表示左边还剩 \(n\) 个点,右边还剩 \(m\) 个点,上一步是右边选了 \(k\) 个点,此时能到达右边 \(1\) 号点的概率。

\(G_{n,m,k}\) 表示左边还剩 \(n\) 个点,右边还剩 \(m\) 个点,上一步是左边选了 \(k\) 个点,此时能到达右边 \(1\) 号点的概率。

转移方程:

\(F_{n,m,k}=\sum_{i=1}^{n}C_n^i\left(1-\frac{1}{2^k}\right)^i\left(\frac{1}{2^k}\right)^{n-i}G_{n,m-k,i}\)

\(G_{n,m,k}=\left(1-\frac{1}{2^k}\right)+\frac{1}{2^k}\left(\sum_{i=1}^{m-1}C_{m-1}^i\left(1-\frac{1}{2^k}\right)^i\left(\frac{1}{2^k}\right)^{m-1-i}F_{n-k,m,i}\right)\)

\(f_{n,m,k}=\sum_{i=1}^{n}C_n^i\left(1-\frac{1}{2^k}\right)^i\left(\frac{1}{2^k}\right)^{n-i}\left(g_{n,m-k,i}+G_{n,m-k,i}\right)\)

\(g_{n,m,k}=\left(1-\frac{1}{2^k}\right)+\frac{1}{2^k}\left(\sum_{i=1}^{m-1}C_{m-1}^i\left(1-\frac{1}{2^k}\right)^i\left(\frac{1}{2^k}\right)^{m-1-i}\left(f_{n-k,m,i}+F_{n-k,m,i}\right)\right)\)

理解:

  1. 左边还剩 \(n\) 个点,右边还剩 \(m\) 个点,上一步是右边选了 \(k\) 个点,此时假设下一步是左边选 \(i\) 个点,且左边仅有这 \(i\) 个点与右边 \(k\) 个点相连,那么从 \(n\) 个点中选出 \(i\) 个点有 \(C_n^i\) 种,这 \(i\) 个点中每个点与右边 \(k\) 个点都至少有一条连边,有 \(\left(1-\frac{1}{2^k}\right)^i\) 的概率,其他 \(n-i\) 个点与右边 \(k\) 个点都没有连边,有 \(\left(\frac{1}{2^k}\right)^{n-i}\) 的概率,剩下局面能到达右边 \(1\) 号点的概率即为 \(G_{n,m-k,i}\)

  2. 左边还剩 \(n\) 个点,右边还剩 \(m\) 个点,上一步是左边选了 \(k\) 个点,此时左边 \(k\) 个点与右边 \(1\) 号点有连边的概率为 \(\left(1-\frac{1}{2^k}\right)\),没有连边的概率为 \(\frac{1}{2^k}\),此时假设下一步是右边选 \(i\) 个点(且不能选 \(1\) 号点),且右边仅有这 \(i\) 个点与左边 \(k\) 个点相连,那么从 \(m-1\) 个点中选出 \(i\) 个点有 \(C_{m-1}^i\) 种,这 \(i\) 个点中每个点与左边 \(k\) 个点都至少有一条连边,有 \(\left(1-\frac{1}{2^k}\right)^i\) 的概率,其他 \(n-i\) 个点与左边 \(k\) 个点都没有连边,有 \(\left(\frac{1}{2^k}\right)^{n-i}\) 的概率,剩下局面能到达右边 \(1\) 号点的概率即为 \(F_{n-k,m,i}\)

  3. 左边还剩 \(n\) 个点,右边还剩 \(m\) 个点,上一步是右边选了 \(k\) 个点,此时假设下一步是左边选 \(i\) 个点,且左边仅有这 \(i\) 个点与右边 \(k\) 个点相连,那么从 \(n\) 个点中选出 \(i\) 个点有 \(C_n^i\) 种,这 \(i\) 个点中每个点与右边 \(k\) 个点都至少有一条连边,有 \(\left(1-\frac{1}{2^k}\right)^i\) 的概率,其他 \(n-i\) 个点与右边 \(k\) 个点都没有连边,有 \(\left(\frac{1}{2^k}\right)^{n-i}\) 的概率,当前转移对答案贡献为 \(G_{n,m-k,i}\)(因为两点不互通距离视作 \(0\)),剩下局面到达右边 \(1\) 号点的期望距离为 \(g_{n,m-k,i}\)

  4. 左边还剩 \(n\) 个点,右边还剩 \(m\) 个点,上一步是左边选了 \(k\) 个点,此时左边 \(k\) 个点与右边 \(1\) 号点有连边的概率为 \(\left(1-\frac{1}{2^k}\right)\),当前转移对答案贡献为 \(\left(1-\frac{1}{2^k}\right)\),没有连边的概率为 \(\frac{1}{2^k}\),此时假设下一步是右边选 \(i\) 个点(且不能选 \(1\) 号点),且右边仅有这 \(i\) 个点与左边 \(k\) 个点相连,那么从 \(m-1\) 个点中选出 \(i\) 个点有 \(C_{m-1}^i\) 种,这 \(i\) 个点中每个点与左边 \(k\) 个点都至少有一条连边,有 \(\left(1-\frac{1}{2^k}\right)^i\) 的概率,当前转移对答案贡献为 \(F_{n,m-k,i}\)(因为两点不互通距离视作 \(0\)),剩下局面到达右边 \(1\) 号点的期望距离为 \(f_{n,m-k,i}\)

时间复杂度 \(O(Tn^4)\)

于是写出代码:

#include<cstdio>
#include<cstring>
#define Mx 30
#define LL long long
int Test_num,n,m,p;
LL fac[32],inv[32];
LL F[32][32][32],G[32][32][32],f[32][32][32],g[32][32][32];
inline LL Pow(LL a,LL b)
{
	if(!b)return 1;
	if(b==1)return a;
	LL c=Pow(a,(b>>1));
	c=(c*c)%p;
	if(b&1)c=(c*a)%p;
	return c;
}
inline void init()
{
	fac[0]=1;for(int i=1;i<=Mx;++i)fac[i]=(fac[i-1]*i)%p;
	inv[Mx]=Pow(fac[Mx],p-2);for(int i=Mx;i;--i)inv[i-1]=(inv[i]*i)%p;
}
inline LL C(int a,int b)
{
	return (((fac[a]*inv[b])%p)*inv[a-b])%p;
}
inline LL calcG(int n,int m,int k);
inline LL calcF(int n,int m,int k)
{
	if(~F[n][m][k])return F[n][m][k];LL res=0;
	for(int i=1;i<=n;++i)res=(res+((((C(n,i)*Pow(1-Pow(2,1LL*k*(p-2)),i))%p)*Pow(Pow(2,1LL*k*(p-2)),n-i))%p)*calcG(n,m-k,i))%p;
	return F[n][m][k]=(res+p)%p;
}
inline LL calcG(int n,int m,int k)
{
	if(~G[n][m][k])return G[n][m][k];LL res=0;
	for(int i=1;i<m;++i)res=(res+((((C(m-1,i)*Pow(1-Pow(2,1LL*k*(p-2)),i))%p)*Pow(Pow(2,1LL*k*(p-2)),m-1-i))%p)*calcF(n-k,m,i))%p;
	return G[n][m][k]=(((1-Pow(2,1LL*k*(p-2))+Pow(2,1LL*k*(p-2))*res)%p)+p)%p;
}
inline LL calcg(int n,int m,int k);
inline LL calcf(int n,int m,int k)
{
	if(~f[n][m][k])return f[n][m][k];LL res=0;
	for(int i=1;i<=n;++i)res=(res+((((C(n,i)*Pow(1-Pow(2,1LL*k*(p-2)),i))%p)*Pow(Pow(2,1LL*k*(p-2)),n-i))%p)*(calcg(n,m-k,i)+calcG(n,m-k,i)))%p;
	return f[n][m][k]=res;
}
inline LL calcg(int n,int m,int k)
{
	if(~g[n][m][k])return g[n][m][k];LL res=0;
	for(int i=1;i<m;++i)res=(res+((((C(m-1,i)*Pow(1-Pow(2,1LL*k*(p-2)),i))%p)*Pow(Pow(2,1LL*k*(p-2)),m-1-i))%p)*(calcf(n-k,m,i)+calcF(n-k,m,i)))%p;
	return g[n][m][k]=(((1-Pow(2,1LL*k*(p-2))+Pow(2,1LL*k*(p-2))*res)%p)+p)%p;
}
int main()
{
	scanf("%d",&Test_num);
	while(Test_num--)memset(F,-1,sizeof(F)),memset(G,-1,sizeof(G)),memset(f,-1,sizeof(f)),memset(g,-1,sizeof(g)),scanf("%d%d%d",&n,&m,&p),init(),printf("%lld\n",calcg(n,m,1));
	return 0;
}

发现跑的巨慢,原来是递归调用函数占用太多时间,修改成这样:

#include<cstdio>
#define Mx 30
#define LL long long
int Test_num,n,m,p;
LL fac[32],inv[32];
LL F[32][32][32],G[32][32][32],f[32][32][32],g[32][32][32];
inline LL Pow(LL a,LL b)
{
	if(!b)return 1;
	if(b==1)return a;
	LL c=Pow(a,(b>>1));
	c=(c*c)%p;
	if(b&1)c=(c*a)%p;
	return c;
}
inline void init()
{
	fac[0]=1;for(int i=1;i<=Mx;++i)fac[i]=(fac[i-1]*i)%p;
	inv[Mx]=Pow(fac[Mx],p-2);for(int i=Mx;i;--i)inv[i-1]=(inv[i]*i)%p;
}
inline LL C(int a,int b)
{
	return (((fac[a]*inv[b])%p)*inv[a-b])%p;
}
inline void calcF(int n,int m,int k)
{
	LL res=0;
	for(int i=1;i<=n;++i)res=(res+((((C(n,i)*Pow(1-Pow(2,1LL*k*(p-2)),i))%p)*Pow(Pow(2,1LL*k*(p-2)),n-i))%p)*G[n][m-k][i])%p;
	F[n][m][k]=(res+p)%p;
}
inline void calcG(int n,int m,int k)
{
	LL res=0;
	for(int i=1;i<m;++i)res=(res+((((C(m-1,i)*Pow(1-Pow(2,1LL*k*(p-2)),i))%p)*Pow(Pow(2,1LL*k*(p-2)),m-1-i))%p)*F[n-k][m][i])%p;
	G[n][m][k]=(((1-Pow(2,1LL*k*(p-2))+Pow(2,1LL*k*(p-2))*res)%p)+p)%p;
}
inline void calcf(int n,int m,int k)
{
	LL res=0;
	for(int i=1;i<=n;++i)res=(res+((((C(n,i)*Pow(1-Pow(2,1LL*k*(p-2)),i))%p)*Pow(Pow(2,1LL*k*(p-2)),n-i))%p)*(g[n][m-k][i]+G[n][m-k][i]))%p;
	f[n][m][k]=res;
}
inline void calcg(int n,int m,int k)
{
	LL res=0;
	for(int i=1;i<m;++i)res=(res+((((C(m-1,i)*Pow(1-Pow(2,1LL*k*(p-2)),i))%p)*Pow(Pow(2,1LL*k*(p-2)),m-1-i))%p)*(f[n-k][m][i]+F[n-k][m][i]))%p;
	g[n][m][k]=(((1-Pow(2,1LL*k*(p-2))+Pow(2,1LL*k*(p-2))*res)%p)+p)%p;
}
int main()
{
	scanf("%d",&Test_num);
	while(Test_num--)
	{
		scanf("%d%d%d",&n,&m,&p),init();
		for(int i=1;i<=n;++i)
			for(int j=1;j<=m;++j)
			{
				for(int k=1;k<j;++k)calcF(i,j,k),calcf(i,j,k);
				for(int k=1;k<=i;++k)calcG(i,j,k),calcg(i,j,k);
			}
		printf("%lld\n",g[n][m][1]);
	}
	return 0;
}

发现还是不快,仔细观察发现代码时间复杂度其实是 \(O(Tn^4\log p)\) 的,多了个快速幂的 \(\log\),于是可以优化成这样:

#include<cstdio>
#define Mx 30
#define LL long long
int Test_num,n,m,p;
LL fac[32],inv[32];
LL IPw[32][32],IPw1[32][32];
LL F[32][32][32],G[32][32][32],f[32][32][32],g[32][32][32];
inline LL Pow(LL a,LL b)
{
	if(!b)return 1;
	if(b==1)return a;
	LL c=Pow(a,(b>>1));
	c=(c*c)%p;
	if(b&1)c=(c*a)%p;
	return c;
}
inline void init()
{
	LL res,res1;fac[0]=1;for(int i=1;i<=Mx;++i)fac[i]=(fac[i-1]*i)%p;
	inv[Mx]=Pow(fac[Mx],p-2);for(int i=Mx;i;--i)inv[i-1]=(inv[i]*i)%p;
	for(int i=0;i<=Mx;++i)
	{
		res=Pow(2,1LL*i*(p-2)),res1=1;
		for(int j=0;j<=Mx;++j)IPw[i][j]=res1,res1=(res1*res)%p;
	}
	for(int i=0;i<=Mx;++i)
	{
		res=1-Pow(2,1LL*i*(p-2)),res1=1;
		for(int j=0;j<=Mx;++j)IPw1[i][j]=res1,res1=(res1*res)%p;
	}
}
inline LL C(int a,int b)
{
	return (((fac[a]*inv[b])%p)*inv[a-b])%p;
}
inline void calcF(int n,int m,int k)
{
	LL res=0;
	for(int i=1;i<=n;++i)res=(res+((((C(n,i)*IPw1[k][i])%p)*IPw[k][n-i])%p)*G[n][m-k][i])%p;
	F[n][m][k]=(res+p)%p;
}
inline void calcG(int n,int m,int k)
{
	LL res=0;
	for(int i=1;i<m;++i)res=(res+((((C(m-1,i)*IPw1[k][i])%p)*IPw[k][m-1-i])%p)*F[n-k][m][i])%p;
	G[n][m][k]=(((1-Pow(2,1LL*k*(p-2))+Pow(2,1LL*k*(p-2))*res)%p)+p)%p;
}
inline void calcf(int n,int m,int k)
{
	LL res=0;
	for(int i=1;i<=n;++i)res=(res+((((C(n,i)*IPw1[k][i])%p)*IPw[k][n-i])%p)*(g[n][m-k][i]+G[n][m-k][i]))%p;
	f[n][m][k]=res;
}
inline void calcg(int n,int m,int k)
{
	LL res=0;
	for(int i=1;i<m;++i)res=(res+((((C(m-1,i)*IPw1[k][i])%p)*IPw[k][m-1-i])%p)*(f[n-k][m][i]+F[n-k][m][i]))%p;
	g[n][m][k]=(((1-Pow(2,1LL*k*(p-2))+Pow(2,1LL*k*(p-2))*res)%p)+p)%p;
}
int main()
{
	scanf("%d",&Test_num);
	while(Test_num--)
	{
		scanf("%d%d%d",&n,&m,&p),init();
		for(int i=1;i<=n;++i)
			for(int j=1;j<=m;++j)
			{
				for(int k=1;k<j;++k)calcF(i,j,k),calcf(i,j,k);
				for(int k=1;k<=i;++k)calcG(i,j,k),calcg(i,j,k);
			}
		printf("%lld\n",g[n][m][1]);
	}
	return 0;
}

这就是真正的 \(O(Tn^4)\) 正解了!

完结撒花!

posted @ 2021-08-12 17:10  18Michael  阅读(72)  评论(0编辑  收藏  举报