【Luogu P5343】【XR-1】分块
题目链接:
题目大意:
给定两个数组 \(A,B\),现在要给 \(n\) 分组,每块的长度必须在两个数组的交集里(即 \(\sum_{i=1}a_i=n\quad(a_i\in A\cap B)\)),求方案数。
正文:
很明显的一个 DP,设 \(f_i\) 表示长度为 \(i\) 的分组的方案数,那么这么转移:
\[f_i=\sum_{j=1}^{m}f_{i-a_j}
\]
\(m\) 表示最长的可能的块长。
时间复杂度:\(O(nm)\),\(n\leq10^{18}\),这玩意过不去。
考虑用矩阵乘法加速递推。
按照矩乘加速递推的尿性,我们可以用一个 \(m\times 1\) 的矩阵来表示状态:
\[\begin{bmatrix}
f_{m}\\
f_{m-1}\\
\vdots\\
f_1
\end{bmatrix}\]
要想办法转移到 \(\begin{bmatrix} f_{m+1}\\ f_{m}\\ \vdots\\ f_3\\ f_2 \end{bmatrix}\),必须再构造一个 \(m\times m\) 的转移矩阵。
如果要转移到 \(f_{m+1}\),根据转移方程 \(f_i=\sum_{j=1}^{m}f_{i-a_j}\) 和矩阵乘法的运算可以推断转移矩阵的第一行(转移到 \(f_{m+1}\))肯定和 \(A\cap B\) 相关。
拿样例为例。
样例中 \(A\cap B = \{1,2\}\),转移矩阵第一行相应的位置就是 \(1\),比如样例的转移矩阵的第一行就是 \(\begin{bmatrix}1&1&0&0&\cdots&0\end{bmatrix}\)。因为这样,\(f_{m+1}\) 就能被转移得到:
\[\begin{bmatrix}
f_{m}\\
f_{m-1}\\
\vdots\\
f_1
\end{bmatrix} \times\begin{bmatrix}1&1&0&0&\cdots&0\end{bmatrix}=\begin{bmatrix}f_{m+1-1}\times 1+f_{m+1-2}\times 1+f_{m+1-2}\times0+\cdots+f_{1+1-1}\times0\end{bmatrix}=\begin{bmatrix}f_{m+1}\end{bmatrix}\]
接下来 \(m-1\) 行就更好办了,由于目标矩阵第 \(2\) 到第 \(m\) 行就是初始矩阵的第 \(1\) 到第 \(m-1\) 行,那转移矩阵就是:
\[\begin{bmatrix}1&1&\cdots&0&0\\
1&0&\cdots&0&0\\
0&1&\cdots&0&0\\
\vdots&\vdots&\ddots&\vdots&\vdots\\
0&0&\cdots&1&0\end{bmatrix}\]
再套个矩阵快速幂就A了。
代码:
int f[N];
ll n;
bool PR[N], NF[N];
struct matrix
{
ll mat[N][N];
int n, m;
matrix(){memset(mat, 0, sizeof mat);}
inline ll* operator [] (int b) { return mat[b];}
}F, stp;
inline matrix operator*(matrix &a, matrix &b)
{
matrix c; c.n = a.n, c.m = b.m;
for (int i = 1; i <= a.n; i++)
for (int j = 1; j <= b.m; j++)
for (int k = 1; k <= a.m; k++)
c[i][j] = (c[i][j] + (a[i][k] * b[k][j]) % mod) % mod;
return c;
}
matrix qpow(matrix stp, ll b)
{
matrix ans; ans.n = ans.m = stp.n;
for (int i = 1; i <= ans.n; i++)
ans[i][i] = 1;
for (; b; b >>= 1)
{
if(b & 1) ans = ans * stp;
stp = stp * stp;
}
return ans;
}
int m;
int main()
{
scanf ("%lld", &n);
int pr;scanf ("%d", &pr);
for (int i = 1, x; i <= pr; i++)
scanf ("%d", &x), PR[x] = 1;
int nf;scanf ("%d", &nf);
for (int i = 1, x; i <= nf; i++)
scanf ("%d", &x), NF[x] = (1 & PR[x]), m = (NF[x] == 1? max(x, m): m);
stp.n = stp.m = F.n = m, F.m = 1;
f[0] = 1;
for (int i = 1; i < m; i++)
{
for (int j = 1; j <= i; j++)
if(NF[j]) f[i] = (f[i] + f[i - j]) % mod;
F[m - i][1] = f[i];
}
F[m][1] = 1;
for (int i = 1; i <= m; i++)
stp[1][i] = NF[i],
stp[i + 1][i] = 1;
stp[m + 1][m] = 0;
F = qpow(stp, n - m + 1ll) * F;
printf("%lld", F[1][1]);
return 0;
}