[Luogu]P2182 翻硬币
Before the Beginning
转载请将本段放在文章开头显眼处,如有二次创作请标明。
原文链接:https://www.codein.icu/lp2182/
题意
\(N\) 个硬币,一次翻转恰好 \(M\) 个,恰好 \(K\) 次后到达目标状态的方案数。
朴素解法
先将正反面转化为是否与目标状态相同方便处理。
定义 \(f(i,k)\) 为翻转了 \(i\) 次,与目标状态有 \(k\) 个硬币相同的方案数。
考虑转移,枚举翻转 \(m\) 个硬币中,\(j\) 个硬币从相同翻为不同, \(m - j\) 个硬币从不同翻为相同。
那么计算得出新的相同数量即为 \(k + m - 2 \times j\),判断是否在合法范围内。
从状态 \(f(i-1,k)\) 转移到 \(f(i,k + m - 2 \times j)\) 时,要在 \(k\) 选择 \(j\) 个相同的硬币翻转为不同,\(n - k\) 中选择 \(m - j\) 个硬币翻转为相同,乘上相应的组合数即可。
可以得到朴素的 DP 解法,复杂度约为 \(O(kn^2)\)。
代码实现中使用了滚动数组优化空间。
#include <stdio.h>
#include <memory.h>
const int maxn = 110;
const int mod = 1000000007;
int n,k,m;
int dif;
char s1[maxn],s2[maxn];
long long f[2][maxn];//f[i] i sames.
long long c[maxn][maxn];
int main()
{
scanf("%d %d %d",&n,&k,&m);
c[1][1] = 1;
for(int i = 2;i<=n+1;++i)
for(int j = 1;j<=i;++j)
c[i][j] = (c[i-1][j] + c[i-1][j-1]) % mod;
scanf("%s",s1+1);
scanf("%s",s2+1);
for(int i = 1;i<=n;++i) dif += (s1[i] != s2[i]);
int now = 0,last = 1;
f[last][n - dif] = 1;
while(k--)
{
memset(f[now],0,sizeof(f[now]));
for(int i = 0;i<=n;++i)// origin same
for(int j = 0;j<=m;++j)//choose j same->diff, n - j diff->same
{
int num = i - j + m - j;
if(num > n || num < 0) continue;
f[now][num] += ((f[last][i] * c[i+1][j+1]) % mod * c[n-i+1][m-j+1]) % mod;
if(f[now][num] > mod) f[now][num] %= mod;
}
now ^= 1,last ^= 1;
}
printf("%lld",f[last][n] % mod);
return 0;
}
矩乘优化
不难发现,每次的转移都是在累加上次的某个状态的结果乘上某个系数,并且系数是不变的。
\(f(i,k + m - 2 \times j) += f(i - 1,k) \times \ldots\)
那么整理一下,一定可以写成:
\(f(i,k) = \sum f(i-1,j) \times A(j,k)\)
可以使用矩阵乘法来加速DP,复杂度减少为 \(O(n^3 \log k)\)。
具体地,用原先的DP过程来构造 \(A\) 矩阵,用矩阵快速幂获得 \(A^k\),再乘上初始状态矩阵即可获得最终状态矩阵。
该方法在 \(n\) 较小, \(k\) 较大时表现出色,但在本题中并无优势。
代码实现中,将矩阵整体下标加一避免零的出现。
#include <cstdio>
const int maxn = 110;
const long long mod = 1000000007;
int n, m, k, cnt;
char s1[maxn], s2[maxn];
long long c[maxn][maxn];
inline void init()
{
c[1][1] = 1ll;
for (int i = 2; i <= n + 5; ++i)
for (int j = 1; j <= i; ++j)
c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % mod;
}
struct matrix
{
int n, m;
long long a[maxn][maxn];
matrix(int n, int m)
{
this->n = n, this->m = m;
for (int i = 1; i <= n; ++i)
for (int j = 1; j <= m; ++j)
a[i][j] = 0;
}
matrix operator*(const matrix b)
{
matrix res(n, b.m);
for (int i = 1; i <= n; ++i)
for (int j = 1; j <= m; ++j)
for (int k = 1; k <= b.m; ++k)
res.a[i][k] = (res.a[i][k] + a[i][j] * b.a[j][k]) % mod;
return res;
}
};
signed main()
{
scanf("%d %d %d", &n, &k, &m), init();
scanf("%s", s1 + 1), scanf("%s", s2 + 1);
for (int i = 1; i <= n; ++i) cnt += (s1[i] == s2[i]);
matrix A(n + 1, n + 1), S(1, n + 1), T(n + 1, n + 1);
for (int i = 0; i <= n; ++i) for (int j = 0; j <= m; ++j)
{
int num = i + m - 2 * j;
if (num > n || num < 0) continue;
A.a[i + 1][num + 1] = (A.a[i + 1][num + 1] + c[i + 1][j + 1] * c[n - i + 1][m - j + 1]) % mod;
}
S.a[1][cnt + 1] = 1;
for (int i = 1; i <= T.n; ++i) T.a[i][i] = 1;
for (; k; k >>= 1)
{
if (k & 1) T = T * A;
A = A * A;
}
S = S * T;
printf("%lld\n", S.a[1][n + 1]);
return 0;
}