【BZOJ4944】【NOI2017】泳池 概率DP 常系数线性递推 特征多项式 多项式取模

题目大意

  有一个\(1001\times n\)的的网格,每个格子有\(q\)的概率是安全的,\(1-q\)的概率是危险的。

  定义一个矩形是合法的当且仅当:

  • 这个矩形中每个格子都是安全的
  • 必须紧贴网格的下边界

  问你最大的合法子矩形大小为\(k\)的概率是多少。

  \(n\leq {10}^9,k\leq 1000\)

  吉老师:这题本来是\(k\leq 20000\)

题解

  一道好题。

  我们计算最大子矩形不超过\(i\)的答案\(s_i\),那么答案就是\(s_k-s_{k-1}\)

  显然最后一行连续的安全格子不会超过\(k\)个。

  设\(g_{i,j}\)表示长度为\(j\),高度为\(i\)的海域全部是安全的,剩下的部分未知,最大子矩形\(\leq k\)的概率。

  设\(h_{i,j}\)表示长度为\(j\),高度为\(i+1\)的海域中,前\(i\)行全部是安全的,剩下的未知且\((i+1,j)\)是危险的,最大子矩形\(\leq k\)的概率。

  边界:

\[\begin{align} g_{k,1}&=q^k(1-q)\\ g_{i,0}&=1\\ h_{i,0}&=1 \end{align} \]

  那么我们从\(k-1\)\(1\)DP,对于\(i\)\(j\)列,枚举第\(i+1\)行的下一个危险的格子在哪个地方,然后转移:

\[\begin{align} g_{i,j}&=\sum_{k=0}^{j}h_{i,k}g_{i+1,j-k}\\ h_{i,j}&=\sum_{k=0}^{j-1}h_{i,k}g_{i+1,j-k-1}q^i(1-q) \end{align} \]

  因为第\(i\)行的宽度不会超过\(\lfloor\frac{k}{i}\rfloor\),所以的暴力的时间复杂度是\(\sum_{i=1}^k{\lfloor\frac{k}{i}\rfloor}^2=O(k^2)\)

  这已经足够了,但我们可以做的更好。

  设

\[\begin{align} A_i(x)&=\sum_{j\geq 0}g_{i,j}x^j\\ B_i(x)&=\sum_{j\geq 0}h_{i,j}x^j\\ c_i&=q^i(1-q)\\ \end{align} \]

那么

\[\begin{align} A_i(x)&=B_i(x)A_{i+1}(x)\\ B_i(x)&=c_ixA_{i+1}(x)B_i(x)+1\\ B_i(x)&=\frac{1}{1-c_ixA_{i+1}(x)}\\ \end{align} \]

  时间复杂度是\(\sum_{i=1}^k\lfloor\frac{k}{i}\rfloor\log\lfloor\frac{k}{i}\rfloor=O(k\log^2k)\)

  设\(f_i\)为前\(i\)列最大子矩形\(\leq k\)的概率,那么

\[f_i=\sum_{j=1}^kf_{i-j-1}g_{1,j}(1-q) \]

  这就是一个常系数线性递推。

\[\begin{align} a_i&=g_{1,i-1}(1-q)\\ f_i&=\sum_{j=1}^kf_{i-j}a_j \end{align} \]

  时间复杂度:

  • 暴力:\(O(nk)\)\(70\)pts
  • 矩阵快速幂:\(O(k^3\log n)\)\(90\)pts
  • 特征多项式+暴力:\(O(k^2\log n)\)\(100\)pts
  • 特征多项式+NTT取模:\(O(k\log k\log n)\)\(100\)pts

  这里简单讲一下最后一个做法

  矩阵快速幂是给你一个矩阵\(A\),求\((A^n)_{1,1}\)

  设矩阵的大小为\(k\)

  根据Cayley-Hamilton定理,\(|\lambda I-A|\)是一个关于\(\lambda\)\(k\)次多项式,记为\(g(\lambda)\)。对于任意矩阵\(A\),有\(g(A)=0\)

  对于常系数线性递推的矩阵,设\(f_i=\sum_{j=1}^kf_{i-j}a_j\)\(g(\lambda)=\lambda^k-\sum_{i=1}^{k}a_{i}\lambda^{k-i}\)

  所以我们只需要求\(A^n\mod g(A)\)。可以用快速幂(倍增取模)求解。

  然后还要求出\(f_1\ldots f_k\),可以通过其他方法计算(多项式求逆或者题目给你了)。

  最后一次卷积可以得到答案。

  如果要求\(f_{n-k+1}\ldots f_n\),那就把\(f_1\ldots f_{2k}\)带进去卷积。

  总时间复杂度:\(O(k\log^2k+k\log k\log n)\)

代码

  暴力取模

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<cmath>
#include<functional>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
void sort(int &a,int &b)
{
	if(a>b)
		swap(a,b);
}
void open(const char *s)
{
#ifndef ONLINE_JUDGE
	char str[100];
	sprintf(str,"%s.in",s);
	freopen(str,"r",stdin);
	sprintf(str,"%s.out",s);
	freopen(str,"w",stdout);
#endif
}
int rd()
{
	int s=0,c;
	while((c=getchar())<'0'||c>'9');
	do
	{
		s=s*10+c-'0';
	}
	while((c=getchar())>='0'&&c<='9');
	return s;
}
int upmin(int &a,int b)
{
	if(b<a)
	{
		a=b;
		return 1;
	}
	return 0;
}
int upmax(int &a,int b)
{
	if(b>a)
	{
		a=b;
		return 1;
	}
	return 0;
}
ll p=998244353;
void add(ll &a,ll b)
{
	a=(a+b)%p;
}
ll fp(ll a,ll b)
{
	ll s=1;
	for(;b;b>>=1,a=a*a%p)
		if(b&1)
			s=s*a%p;
	return s;
}
ll inv(ll a)
{
	return fp(a,p-2);
}
ll pw1[1010];
ll pw2[1010];
ll q;
ll q2;
ll g[1010][1010];
ll h[1010][1010];
ll f[2010];
ll a[2010];
ll c[2010];
ll d[2010];
ll final[2010];
void mul(ll *a,ll *b,ll *e,int len)
{
	static ll c[2010];
	int i,j;
	for(i=0;i<=2*len;i++)
		c[i]=0;
	for(i=0;i<=len;i++)
		for(j=0;j<=len;j++)
			add(c[i+j],a[i]*b[j]);
	for(i=2*len;i>=len;i--)
	{
		ll v=c[i]*inv(e[len]);
		if(v)
			for(j=0;j<=len;j++)
				c[i-len+j]=(c[i-len+j]-e[j]*v)%p;
	}
	for(i=0;i<=len;i++)
		a[i]=c[i];
}
ll solve(int n,int k)
{
	if(!k)
		return fp(q2,n);
	memset(g,0,sizeof g);
	memset(h,0,sizeof h);
	g[k][1]=q2*pw1[k]%p;
	g[k][0]=1;
	int i,j,l;
	for(i=k-1;i>=1;i--)
	{
		int m=k/i;
		g[i][0]=1;
		h[i][0]=1;
		for(j=0;j<=m;j++)
		{
			for(l=j+1;l<=m;l++)
				add(h[i][l],h[i][j]*g[i+1][l-j-1]%p*q2%p*pw1[i]%p);
			for(l=j;l<=m;l++)
				if(l)
					add(g[i][l],h[i][j]*g[i+1][l-j]%p);
		}
	}
	memset(f,0,sizeof f);
	f[0]=1;
	for(i=1;i<=2*(k+1);i++)
		for(j=0;j<i&&j<=k;j++)
			add(f[i],f[i-j-1]*q2%p*g[1][j]);
	if(n<=2*(k+1))
	{
		ll s=0;
		for(i=0;i<=n&&i<=k;i++)
			add(s,f[n-i]*g[1][i]);
		return s;
	}
	int len=k+1;
	for(i=0;i<len;i++)
		a[i]=-q2*g[1][len-i-1]%p;
	a[len]=1;
	memset(c,0,sizeof c);
	c[1]=1;
	memset(d,0,sizeof d);
	d[0]=1;
	int m=n-k-1;
	while(m)
	{
		if(m&1)
			mul(d,c,a,len);
		mul(c,c,a,len);
		m>>=1;
	}
	memset(final,0,sizeof final);
	for(i=1;i<=k+1;i++)
		for(j=0;j<=k;j++)
			add(final[i],d[j]*f[i+j]);
	ll s=0;
	for(i=1;i<=k+1;i++)
		add(s,final[i]*g[1][k+1-i]);
	return s;
}
int main()
{
	open("bzoj4944");
	int n,k,x,y;
	scanf("%d%d%d%d",&n,&k,&x,&y);
	q=x*inv(y)%p;
	q2=(y-x)*inv(y)%p;
	pw1[0]=pw2[0]=1;
	int i;
	for(i=1;i<=k;i++)
	{
		pw1[i]=pw1[i-1]*q%p;
		pw2[i]=pw2[i-1]*q2%p;
	}
	ll ans1=solve(n,k);
	ll ans2=solve(n,k-1);
	ll ans=((ans1-ans2)%p+p)%p;
	printf("%lld\n",ans);
	return 0;
}

  NTT取模

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<cmath>
#include<functional>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
void sort(int &a,int &b)
{
	if(a>b)
		swap(a,b);
}
void open(const char *s)
{
#ifndef ONLINE_JUDGE
	char str[100];
	sprintf(str,"%s.in",s);
	freopen(str,"r",stdin);
	sprintf(str,"%s.out",s);
	freopen(str,"w",stdout);
#endif
}
int rd()
{
	int s=0,c;
	while((c=getchar())<'0'||c>'9');
	do
	{
		s=s*10+c-'0';
	}
	while((c=getchar())>='0'&&c<='9');
	return s;
}
int upmin(int &a,int b)
{
	if(b<a)
	{
		a=b;
		return 1;
	}
	return 0;
}
int upmax(int &a,int b)
{
	if(b>a)
	{
		a=b;
		return 1;
	}
	return 0;
}
const ll p=998244353;
const int maxn=300000;
ll fp(ll a,ll b)
{
	ll s=1;
	for(;b;b>>=1,a=a*a%p)
		if(b&1)
			s=s*a%p;
	return s;
}
namespace ntt
{
	const ll g=3;
    ll w1[maxn];
    ll w2[maxn];
    int rev[maxn];
    int n;
    void init(int m)
    {
        n=1;
        while(n<m)
            n<<=1;
        int i;
        for(i=2;i<=n;i<<=1)
        {
            w1[i]=fp(g,(p-1)/i);
            w2[i]=fp(w1[i],p-2);
        }
        rev[0]=0;
        for(i=1;i<n;i++)
            rev[i]=(rev[i>>1]>>1)|((i&1)*(n>>1));
    }
    void ntt(ll *a,int t)
    {
        int i,j,k;
        ll u,v,w,wn;
        for(i=0;i<n;i++)
            if(rev[i]<i)
                swap(a[i],a[rev[i]]);
        for(i=2;i<=n;i<<=1)
        {
            wn=(t==1?w1[i]:w2[i]);
            for(j=0;j<n;j+=i)
            {
                w=1;
                for(k=j;k<j+i/2;k++)
                {
                    u=a[k];
                    v=a[k+i/2]*w%p;
					a[k]=(u+v)%p;
					a[k+i/2]=(u-v)%p;
                    w=w*wn%p;
                }
            }
        }
        if(t==-1)
        {
            u=fp(n,p-2);    
            for(i=0;i<n;i++)
                a[i]=a[i]*u%p;
        }
    }
    ll x[maxn];
    ll y[maxn];
    ll z[maxn];
    void copy_clear(ll *a,ll *b,int m)
    {
        int i;
        for(i=0;i<m;i++)
            a[i]=b[i];
        for(i=m;i<n;i++)
            a[i]=0;
    }
    void copy(ll *a,ll *b,int m)
    {
        int i;
        for(i=0;i<m;i++)
            a[i]=b[i];
    }
    void mul(ll *a,ll *b,ll *c,int m)
    {
    	init(m<<1);
    	copy_clear(x,a,m);
    	copy_clear(y,b,m);
    	ntt(x,1);
    	ntt(y,1);
    	int i;
    	for(i=0;i<n;i++)
    		x[i]=x[i]*y[i]%p;
    	ntt(x,-1);
    	copy(c,x,m);
    }
    void inverse(ll *a,ll *b,int m)
    {
        if(m==1)
        {
            b[0]=fp(a[0],p-2);
            return;
        }
        inverse(a,b,m>>1);
        init(m<<1);
        copy_clear(x,a,m);
        copy_clear(y,b,m>>1);
        ntt(x,1);
        ntt(y,1);
        int i;
        for(i=0;i<n;i++)
            x[i]=y[i]*(2-x[i]*y[i]%p)%p;
    	ntt(x,-1);
    	copy(b,x,m);
    }
    ll c[maxn],d[maxn],e[maxn],f[maxn];
    void sqrt(ll *a,ll *b,int m)
    {
    	if(m==1)
    	{
    		if(a[0]==1)
    			b[0]=1;
    		else if(a[0]==0)
    			b[0]=0;
    		else
    			//我也不会
				;
			return;
		}
		sqrt(a,b,m>>1);
//		copy_clear(c,b,m>>1);
		int i;
		for(i=m;i<m<<1;i++)
			b[i]=0;
		inverse(b,d,m);
		init(m<<1);
		for(i=m;i<m<<1;i++)
			b[i]=d[i]=0;
		ll inv2=fp(2,p-2);
		copy_clear(x,a,m);
		ntt(x,1);
		ntt(d,1);
		for(i=0;i<n;i++)
			x[i]=x[i]*d[i]%p;
		ntt(x,-1);
		for(i=0;i<m;i++)
			b[i]=((b[i]+x[i])%p*inv2)%p;
	}
    void derivative(ll *a,ll *b,int m)
	{
		int i;
		for(i=0;i<m-1;i++)
			b[i]=(i+1)*a[i+1]%p;
		b[m-1]=0;
	}
    void differential(ll *a,ll *b,int m)
    {
//    	int i;
//    	for(i=m-1;i>=1;i--)
//    		b[i]=a[i-1]*inv[i]%p;
    	b[0]=0;
    }
    void ln(ll *a,ll *b,int m)
    {
    	static ll c[maxn],d[maxn];
    	derivative(a,c,m);
    	inverse(a,d,m);
    	init(m<<1);
    	int i;
    	for(i=m;i<n;i++)
    		c[i]=d[i]=0;
    	ntt(c,1);
    	ntt(d,1);
    	for(i=0;i<n;i++)
    		c[i]=c[i]*d[i]%p;
    	ntt(c,-1);
    	differential(c,b,m);
    }
    void exp(ll *a,ll *b,int m)
    {
    	if(m==1)
    	{
    		b[0]=1;
    		return;
    	}
    	exp(a,b,m>>1);
    	int i;
    	for(i=m>>1;i<m;i++)
    		b[i]=0;
    	ln(b,y,m);
    	init(m<<1);
    	copy_clear(x,a,m);
    	x[0]++;
    	for(i=0;i<m;i++)
    		x[i]=(x[i]-y[i])%p;
    	copy_clear(y,b,m);
    	ntt(x,1);
    	ntt(y,1);
    	for(i=0;i<n;i++)
    		x[i]=x[i]*y[i]%p;
    	ntt(x,-1);
    	copy(b,x,m);
    }
    void module(ll *a,ll *b,ll *c,int n1,int n2)
    {
    	int k=1;
    	while(k<=n1-n2+1)
    		k<<=1;
    	int i;
    	for(i=0;i<=n1;i++)
    		d[i]=a[i];
    	for(i=0;i<=n2;i++)
    		e[i]=b[i];
    	reverse(d,d+n1+1);
    	reverse(e,e+n2+1);
    	for(i=n1-n2+1;i<k<<1;i++)
    		d[i]=e[i]=0;
    	inverse(e,f,k);
    	for(i=n1-n2+1;i<k<<1;i++)
    		f[i]=0;
    	init(k<<1);
    	ntt::ntt(d,1);
    	ntt::ntt(f,1);
    	for(i=0;i<n;i++)
    		e[i]=d[i]*f[i]%p;
    	ntt::ntt(e,-1);
    	for(i=0;i<=n1-n2;i++)
    		c[i]=e[i];
    	reverse(c,c+n1-n2+1);
    }
};
void add(ll &a,ll b)
{
	a=(a+b)%p;
}
ll inv(ll a)
{
	return fp(a,p-2);
}
ll pw1[maxn];
ll pw2[maxn];
ll q;
ll q2;
ll f[maxn];
ll a[maxn];
ll c[maxn];
ll d[maxn];
ll final[maxn];
ll g[2][maxn];
ll h[maxn];
ll e[maxn];

void mul(ll *a,ll *b,ll *c,int n)
{
	static ll d[maxn],e[maxn];
	int k=1;
	while(k<=n)
		k<<=1;
	ntt::init(k<<1);
	int i;
	for(i=0;i<k<<1;i++)
		d[i]=e[i]=0;
	for(i=0;i<=n;i++)
	{
		d[i]=a[i];
		e[i]=b[i];
	}
	ntt::ntt(d,1);
	ntt::ntt(e,1);
	for(i=0;i<k<<1;i++)
		d[i]=d[i]*e[i]%p;
	ntt::ntt(d,-1);
	//d=a*b
	for(i=0;i<k<<1;i++)
		e[i]=0;
	int n2=(k<<1)-1;
	while(!d[n2])
		n2--;
	ntt::module(d,c,e,n2,n);
	for(i=0;i<n;i++)
		a[i]=d[i];
	for(i=0;i<k;i++)
		d[i]=c[i];
	for(i=k;i<k<<1;i++)
		d[i]=0;
	ntt::init(k<<1);
	ntt::ntt(d,1);
	ntt::ntt(e,1);
	for(i=0;i<k<<1;i++)
		d[i]=d[i]*e[i]%p;
	ntt::ntt(d,-1);
	for(i=0;i<n;i++)
		a[i]=(a[i]-d[i])%p;
}
void powmod(ll *a,ll *b,ll *c,int m,int n)
{
	if(!n)
		return;
	powmod(a,b,c,m,n>>1);
	mul(a,a,c,m);
	if(n&1)
		mul(a,b,c,m);
}
ll solve(int n,int k)
{
	memset(g,0,sizeof g);
	memset(h,0,sizeof h);
	int now=0;
	g[now][1]=q2*pw1[k]%p;
	g[now][0]=1;
	h[0]=1;
	int i,j;
	for(i=k-1;i>=1;i--)
	{
		now^=1;
		int m=k/i;
		ll c=q2*pw1[i]%p;
		int len=1;
		while(len<=m)
			len<<=1;
		for(j=1;j<len;j++)
			e[j]=-c*g[now^1][j-1];
		e[0]=1;
		ntt::inverse(e,h,len);
		for(j=m+1;j<len<<1;j++)
			h[j]=0;
		ntt::init(len<<1);
		ntt::ntt(g[now^1],1);
		ntt::ntt(h,1);
		for(j=0;j<len<<1;j++)
			g[now][j]=g[now^1][j]*h[j]%p;
		ntt::ntt(g[now],-1);
		for(j=m+1;j<len<<1;j++)
			g[now][j]=0;
	}
	memset(a,0,sizeof a);
	for(i=0;i<=k;i++)
		a[i+1]=-g[now][i]*q2%p;
	a[0]=1;
	int len=1;
	while(len<=k+1)
		len<<=1;
	ntt::inverse(a,f,len<<1);
	if(n<=2*(k+1))
	{
		ll s=0;
		for(i=0;i<=n&&i<=k;i++)
			add(s,f[n-i]*g[now][i]);
		return s;
	}
	memset(a,0,sizeof a);
	memset(c,0,sizeof c);
	memset(d,0,sizeof d);
	for(i=0;i<=k;i++)
		a[i]=-g[now][k-i]*q2%p;
	a[k+1]=1;
	if(k)
		c[1]=1;
	else
		c[0]=-a[0];
	d[0]=1;
	int m=n-k;
	powmod(d,c,a,k+1,m);
//	while(m)
//	{
//		if(m&1)
//			mul(d,c,a,k+1);
//		mul(c,c,a,k+1);
//		m>>=1;
////		for(i=0;i<=k;i++)
////			printf("%lld ",(d[i]+p)%p);
////		printf("\n");
//	}
	reverse(d,d+k+1);
	ntt::init(len<<2);
	ntt::ntt(d,1);
	ntt::ntt(f,1);
	for(i=0;i<len<<2;i++)
		final[i]=d[i]*f[i]%p;
	ntt::ntt(final,-1);
	ll s=0;
	for(i=0;i<=k;i++)
		add(s,g[now][i]*final[2*k-i]);
	return s;
//	for(i=0;i<=k;i++)
//		g[now][i]=(g[now][i]+p)%p;
//	memset(f,0,sizeof f);
//	f[0]=1;
//	for(i=1;i<=2*(k+1);i++)
//		for(j=0;j<i&&j<=k;j++)
//			add(f[i],f[i-j-1]*q2%p*g[now][j]);
//	if(n<=2*(k+1))
//	{
//		ll s=0;
//		for(i=0;i<=n&&i<=k;i++)
//			add(s,f[n-i]*g[now][i]);
//		return s;
//	}
//	int len=k+1;
//	for(i=0;i<len;i++)
//		a[i]=-q2*g[now][len-i-1]%p;
//	a[len]=1;
//	memset(c,0,sizeof c);
//	c[1]=1;
//	memset(d,0,sizeof d);
//	d[0]=1;
//	int m=n-k-1;
//	while(m)
//	{
//		if(m&1)
//			mul(d,c,a,len);
//		mul(c,c,a,len);
//		m>>=1;
//	}
//	memset(final,0,sizeof final);
//	for(i=1;i<=k+1;i++)
//		for(j=0;j<=k;j++)
//			add(final[i],d[j]*f[i+j]);
//	ll s=0;
//	for(i=1;i<=k+1;i++)
//		add(s,final[i]*g[now][k+1-i]);
//	return s;
}
int main()
{
	open("bzoj4944");
	int n,k,x,y;
	scanf("%d%d%d%d",&n,&k,&x,&y);
	q=x*inv(y)%p;
	q2=(y-x)*inv(y)%p;
	pw1[0]=pw2[0]=1;
	int i;
	for(i=1;i<=k;i++)
	{
		pw1[i]=pw1[i-1]*q%p;
		pw2[i]=pw2[i-1]*q2%p;
	}
	ll ans1=solve(n,k);
	ll ans2=solve(n,k-1);
	ll ans=((ans1-ans2)%p+p)%p;
	printf("%lld\n",ans);
	return 0;
}
posted @ 2018-03-06 11:29  ywwyww  阅读(1155)  评论(0编辑  收藏  举报