【矩阵乘优化DP】涂色游戏
题目大意
用 \(p\) 种颜色填 \(n\times m\) 的画板,要求任意相邻两列的颜色数都不少于 \(q\) ,求方案数。
数据范围
\(1\leq n\leq 100,1\leq m\leq 10^9,q\leq p\leq 100\)
思路
观摩 \(m\) 的范围,显然需要一个 \(\log m\) 的做法,于是想到了矩阵快速幂。
首先考虑原始的转移。若当前一列涂上 \(j\) 种颜色,下一列要涂 \(k\) 种颜色,则方案数如下:
前一个组合数是 \(j\) 和 \(k\) 颜色中交集的部分,而后一个就是交集的补集。其中边界的意思分别为 \(j\cap k=\varnothing\) 和 \(j\subset k\) 或 \(k\subset j\)。意思就是说,这次的方案组成就是在满足条件的情况下,和上次相交的颜色的选择的方案乘上这次的新颜色的选择的方案。
然后对于每一列,定义 \(g[i][j]\),表示当前填到第 \(i\) 行的格子,涂 \(j\) 种颜色的方案数。则 \(g[n][j]\) 一列中涂 \(n\) 种颜色的方案。这个问题可以转化成有 \(j\) 个不同的盒子,要把 \(i\) 个不同的球放入盒子中,要求非空。这个问题就是第二类斯特林,递推式为:
即可以在放过球的盒子中再放一个,有 \(j\) 种,也可以新选一个没有放过球的盒子,这个新的盒子可以是 \(j\) 中的任何一个。因此一共 \(j\) 种。由于每一列的情况都是类似的,所以可以预处理出来。
那么转移矩阵就出来了。设 \(h[j][k]\) 表示这一列涂 \(j\) 种颜色,下一列涂 \(k\) 种颜色的方案数:
则令 \(f[i][j]\) 为当前选到第 \(i\) 列,当前一列涂了 \(j\) 种颜色的方案数,则可以得到 \(f[i][j]=f[i-1][j]\times h[j][k]\),边界为 \(f[1][j]=g[n][j]\times C_p^j\),表示选 \(j\) 种颜色后涂上。由于 \(f\) 的转移系数与 \(i\) 无关,所以可以用矩阵快速幂优化转移 \(m-1\) 次后得到结果,时间复杂度 \(O(n^3\log m)\)。
代码
注意卡常。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=100+10;
const int Mod=998244353;
int n,m,p,q;
ll res;
ll C[maxn][maxn],g[maxn][maxn];
inline ll add(ll x,ll y){
if(x+y>Mod)return x+y-Mod;
return x+y;
}
struct Mat{
ll a[maxn][maxn];
Mat(){
memset(a,0,sizeof(a));
}
inline void set(){
for(int i=1;i<=n;i++)
a[i][i]=1;
}
friend inline Mat operator *(register const Mat& A,register const Mat& B){
Mat C;
for(register int i=1;i<=p;i++)
for(register int j=1;j<=p;j++)
for(register int k=1;k<=p;k++)
C.a[i][j]=add(C.a[i][j],A.a[i][k]*B.a[k][j]%Mod);
return C;
}
}f,h;
inline int read(){
int x=0;bool fopt=1;char ch=getchar();
for(;!isdigit(ch);ch=getchar())if(ch=='-')fopt=0;
for(;isdigit(ch);ch=getchar())x=(x<<3)+(x<<1)+ch-48;
return fopt?x:-x;
}
inline Mat qpow(Mat x,int b){
Mat ans,base=x;
ans.set();
while(b){
if(b&1)ans=ans*base;
base=base*base;
b>>=1;
}
return ans;
}
inline void Init(){
g[0][0]=C[0][0]=1;
for(int i=1;i<=100;i++){
C[i][0]=1;
for(int j=1;j<=i;j++)
C[i][j]=add(C[i-1][j],C[i-1][j-1]);
}
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
g[i][j]=add(g[i-1][j],g[i-1][j-1])*j%Mod;
}
signed main(){
n=read();m=read();p=read();q=read();
Init();
for(register int j=1;j<=p;j++)
for(register int k=1;k<=p;k++){
int l=max(max(q,j),k),r=min(p,j+k);
for(register int i=l;i<=r;i++)
h.a[j][k]=add(h.a[j][k],C[j][j+k-i]*C[p-j][i-j]%Mod);
h.a[j][k]=h.a[j][k]*g[n][k]%Mod;
}
f=qpow(h,m-1);
for(register int i=1;i<=p;i++)
for(register int j=1;j<=p;j++)
res=add(res,f.a[i][j]*C[p][i]%Mod*g[n][i]%Mod);//简单易懂的求和
printf("%lld\n",res);
return 0;
}