【XSY2745】装饰地板 状压DP 特征多项式

题目大意

  你有\(s_1\)\(1\times 2\)的地砖,\(s_2\)\(2\times 1\)的地砖。

  记铺满\(m\times n\)的地板的方案数为\(f(m,n)\)

  给你\(m,l,r,s_1,s_2\),求\(\sum_{i=l}^rf(m,i)\)

  \(m\leq 6,1\leq l\leq r\leq {10}^{2501}\)

题解

  显然是状压DP。

  显然可以矩阵快速幂。

  怎么矩阵快速幂?

  假设矩阵是\(2^m\times 2^m\)的,我们把矩阵扩大一行一列,记录前面算出的铺满\(m\)\(i\)列的方案数。

  转移在原来转移的基础上增加\(f_{i-1,2^m-1}\longrightarrow f_{i,2^m}\)\(f_{i-1,2^m}\longrightarrow f_{i,2^m}\)

  这样\(f_{i,m}\)就是\(\sum_{j=1}^{i-1}f_{j,2^m-1}\)

  然后就可以矩阵快速幂了。

  显然这个矩阵快速幂是可以用特征多项式+倍增取模优化的。

  求特征多项式可以用\(O(n^3)\)的方法,也可以用\(O(n^4)\)的方法。

  然后就没了。

  时间复杂度:\(O(8^m+4^m\log r)\)

  求矩阵的特征多项式

  Cayley-Hamilton定理&倍增取模

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<utility>
#include<iostream>
#include<vector>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
void open(const char *s)
{
#ifndef ONLINE_JUDGE
	char str[100];
	sprintf(str,"%s.in",s);
	freopen(str,"r",stdin);
	sprintf(str,"%s.out",s);
	freopen(str,"w",stdout);
#endif
}
typedef vector<ll> poly;
const ll p=998244353;
ll fp(ll a,ll b)
{
	ll s=1;
	for(;b;b>>=1,a=a*a%p)
		if(b&1)
			s=s*a%p;
	return s;
}
struct matrix
{
	ll a[130][130];
	int n,m;
	matrix()
	{
		memset(a,0,sizeof a);
		n=m=0;
	}
	ll *operator [](int x)
	{
		return a[x];
	}
};
matrix operator *(matrix a,matrix b)
{
	matrix c;
	c.n=a.n;
	c.m=b.m;
	for(int i=0;i<a.n;i++)
		for(int j=0;j<a.m;j++)
		{
			ll s=0;
			for(int k=0;k<b.m;k++)
				s=(s+a[i][k]*b[k][j])%p;
			c[i][j]=s;
		}
	return c;
}
namespace yww
{
	poly f[200];
	void add(poly &a,poly b,ll s1,ll s2)
	{
		int n=b.size();
		while(int(a.size())<n+1)
			a.push_back(0);
		for(int i=0;i<n;i++)
		{
			a[i]=(a[i]+b[i]*s2)%p;
			a[i+1]=(a[i+1]+b[i]*s1)%p;
		}
		while(a.back()==0)
			a.pop_back();
	}
	poly getpoly(matrix a,int n)
	{
		for(int i=0;i<=n;i++)
		{
			int j;
			for(j=i+1;j<=n;j++)
				if(a[j][i])
					break;
			if(j>n)
				continue;
			if(j!=i+1)
			{
				for(int k=i;k<=n;k++)
					swap(a[i+1][k],a[j][k]);
				for(int k=0;k<=n;k++)
					swap(a[k][i+1],a[k][j]);
			}
			for(int j=i+2;j<=n;j++)
				if(a[j][i])
				{
					ll v=fp(a[i+1][i],p-2)*a[j][i]%p;
					for(int k=i;k<=n;k++)
						a[j][k]=(a[j][k]-a[i+1][k]*v)%p;
					for(int k=0;k<=n;k++)
						a[k][i+1]=(a[k][i+1]+a[k][j]*v)%p;
				}
		}
		f[n+1].push_back(1);
		for(int i=n;i>=0;i--)
		{
			add(f[i],f[i+1],1,-a[i][i]);
			ll v=1;
			for(int j=i+2;j<=n+1;j++)
			{
				v=v*a[j-1][j-2]%p;
				add(f[i],f[j],0,-v*a[i][j-1]%p);
			}
		}
		return f[0];
	}
}
matrix a;
int m,s1,s2,all;
void dfs(int x,int a1,int a2,ll s)
{
	if(x>m+1)
		return;
	if(x>m)
	{
		a[all^a1][a2]=(a[all^a1][a2]+s)%p;
		return;
	}
	dfs(x+1,a1,a2,s);
	dfs(x+1,a1|(1<<(x-1)),a2|(1<<(x-1)),s*s1%p);
	dfs(x+2,a1,a2|(3<<(x-1)),s*s2%p);
}
poly aa;
int len;
//poly f[20];
matrix g[200];
poly operator *(poly a,poly b)
{
	int n=a.size()-1;
	int m=b.size()-1;
	poly c(n+m+1);
	for(int j=0;j<=m;j++)
		if(b[j])
			for(int i=0;i<=n;i++)
				c[i+j]=(c[i+j]+a[i]*b[j])%p;
	return c;
}
poly operator %(poly a,poly b)
{
	int n=a.size()-1;
	int m=b.size()-1;
	for(int i=n;i>=m;i--)
		if(a[i])
		{
			ll v=a[i];
			for(int j=0;j<=m;j++)
				a[i-m+j]=(a[i-m+j]-b[j]*v)%p;
		}
	while(!a.back())
		a.pop_back();
	return a;
}
poly a1;
void init()
{
	all=(1<<m)-1;
	dfs(1,0,0,1);
	a[all][all+1]=1;
	a[all+1][all+1]=1;
	a.n=a.m=all+2;
	aa=yww::getpoly(a,all+1);
	len=aa.size()-1;
//	for(int i=0;i<=len;i++)
//		c[i]=aa[i];
//	f[0].push_back(1);
//	for(int i=1;i<=9;i++)
//	{
//		f[i].resize(i+1);
//		f[i][i]=1;
//		f[i]=f[i]%aa;
////		for(auto v:f[i])
////			printf("%lld ",(v+p)%p);
////		printf("\n");
//	}
	g[0][0][all]=1;
	g[0].n=1;
	g[0].m=a.n;
	for(int i=1;i<len;i++)
	{
		g[i]=g[i-1]*a;
//		printf("%lld\n",(g[i][0][all+1]+p)%p);
	}
	a1.resize(2);
	a1[0]=0;
	a1[1]=1;
	a1=a1%aa;
}
char l[10010],r[10010];
int e[100010];
int bit[100010];
poly f;
void calc(int n)
{
	if(!n)
		return;
	calc(n-1);
	f=f*f%aa;
	if(bit[n])
		f=f*a1%aa;
}
int solve(char *str,int b)
{
	int n=strlen(str+1);
	memset(e,0,sizeof e);
	for(int i=1;i<=n;i++)
		e[i]=str[n-i+1]-'0';
	e[1]+=b;
	int i;
	for(i=1;e[i]>=10;i++)
	{
		e[i+1]+=e[i]/10;
		e[i]%=10;
	}
	n=max(n,i);
	memset(bit,0,sizeof bit);
	int k=1;
	for(int i=n;i>=1;i--)
	{
		int s=0;
		int j;
		for(j=1;j<=k||s;j++)
		{
			s+=bit[j]*10;
			bit[j]=s&1;
			s>>=1;
		}
		k=max(k,j);
		bit[1]+=e[i];
		for(j=1;bit[j]>=2;j++)
		{
			bit[j+1]+=bit[j]>>1;
			bit[j]&=1;
		}
		k=max(k,j);
	}
	f.clear();
	f.push_back(1);
	reverse(bit+1,bit+k+1);
	calc(k);
	ll ans=0;
	for(int i=0;i<f.size()&&i<len;i++)
	{
//		printf("%lld\n",(f[i]+p)%p);
		ans=(ans+f[i]*g[i][0][all+1])%p;
	}
	return ans;
}
int main()
{
	open("b");
	scanf("%s%s",l+1,r+1);
	scanf("%d%d%d",&m,&s1,&s2);
	init();
	ll ans1=solve(l,0);
	ll ans2=solve(r,1);
	ll ans=ans2-ans1;
	ans=(ans%p+p)%p;
	printf("%lld\n",ans);
	return 0;
}
posted @ 2018-03-13 20:08  ywwyww  阅读(222)  评论(0编辑  收藏  举报