NFLSOJ829 【2020六校联考WC #1】LJJ的生日礼物
NFLSOJ829 【2020六校联考WC #1】LJJ的生日礼物
题目大意
一个长度为 \(L\) 的序列。有 \(K\) 种颜色。你需要给每个位置染上一种颜色,使得没有距离 \(\leq 2\) 的两个位置颜色相同。有 \(N\) 个位置 \(p_{1\dots N}\),它们的颜色已经确定了,分别是 \(c_{1\dots N}\)。请求出给剩下位置染色的方案数。答案对 \(10^9 + 7\) 取模。
数据范围:\(0\leq N\leq 1000\),\(1\leq K\leq 10^9\),\(\max(1, N)\leq L\leq 10^9\)。保证 \(p_1 < p_2\dots < p_N\)。
本题题解
考虑朴素的 DP。设 \(\mathrm{dp}_1(i, j, k)\) 表示考虑了序列的前 \(i\) 个位置(\(2\leq i\leq L\)),第 \(i - 1\) 个位置颜色为 \(j\),第 \(i\) 个位置颜色为 \(k\) 的染色方案数。时间复杂度 \(\mathcal{O}(LK^3)\)。
设 \(c_{1\dots N}\) 中总共出现了 \(W\) 种颜色(\(W\leq N\))。可以把剩下所有颜色,视为一种特殊颜色。具体来说,在 DP 状态里,如果 \(j\leq W\),它表示的是某一种出现过的颜色,正常转移即可;否则 \(j = W + 1\),即特殊颜色,这个状态表示:【$j = $ 剩下任意一种颜色时,的 DP 值】之和(这些颜色的 DP 值显然是完全一样的,它们的转移也是完全一样的),于是转移时乘以系数 \(K - W\) 即可。时间复杂度 \(\mathcal{O}(LW^3) = \mathcal{O}(LN^3)\)。
继续优化,我们要让时间复杂度摆脱 \(L\),就必须抛弃上述的、逐个位置 DP 的想法,直接在 \(N\) 个关键点之间进行转移。于是有了一个新的状态设计:设 \(\mathrm{dp}_2(i, j)\) 表示考虑了前 \(i\) 个关键点,\(p_i + 1\) 位置上颜色为 \(j\) 的方案数。转移时,枚举 \(p_{i + 1} + 1\) 的颜色 \(j'\)。要从 \(\mathrm{dp}_2(i, j)\) 转移到 \(\mathrm{dp}_2(i + 1, j')\)。我们想快速求出转移系数。问题可以转化为:有一段长度为 \(p_{i + 1} - p_{i} + 2\) 的序列,开头两个位置颜色分别为 \(c_i, j\),末尾两个位置颜色分别为 \(c_{i + 1}, j'\),求给中间其他位置染色的方案数。
发现这个问题的答案,只和 \(c_{i}, j, c_{i + 1}, j'\) 这四个颜色两两是否相等有关,与它们具体是什么无关。换句话说,等价的情况只有 \(\mathrm{Bell}(4) = 15\) 种(实际更少,因为例如 \(c_i = j\) 等情况是不合法的)。并且因为长度为 \(p_{i + 1} - p_{i} + 2\),共 \(N - 1\) 种,所以只需要对这 \((N - 1)\cdot\mathrm{Bell}(4)\) 个问题分别预处理答案即可。
如何预处理答案?朴素的想法还是做第一个 DP。即 \(\text{dp}_1(i, j, k)\) 表示前 \(i\) 个位置,最后两个位置颜色分别为 \(j,k\) 的方案数。关键的颜色在这里只有 \(4\) 种(\(c_i, j, c_{i + 1}, j'\)),其他颜色可以合并为一种特殊颜色,故状态数是 \(5\times 5 = 25\) 的。进一步,甚至 \(c_i, j\) 这两种颜色也不重要,我们只把 \(c_{i + 1}, j'\) 视为关键颜色,就可以完成 DP。状态数优化为 \(3\times 3 = 9\) 种。减去连续两个颜色相等,这样不合法的状态后,只剩 \(7\) 种。朴素 DP 时间复杂度 \(\mathcal{O}(L \cdot 7)\)。因为状态数 \(7\) 很小,而长度 \(L\) 很大,考虑用矩阵快速幂,可以优化为 \(\mathcal{O}(7^3\log L)\)。
于是,可以在 \(\mathcal{O}(N\cdot \mathrm{Bell}(4)\cdot 7^3\cdot \log L)\) 的时间复杂度内完成所有预处理。接下来用预处理的信息,可以 \(\mathcal{O}(1)\) 回答从 \(\mathrm{dp}_2(i, j)\) 向 \(\mathrm{dp}_2(i + 1, j')\) 转移的系数。现在,这个 DP 的时间复杂度是 \(\mathcal{O}(NK^2)\)。
发现对所有 \(j'\neq j, j'\neq c_i\),\(\mathrm{dp}_2(i,j)\) 对它们的贡献是一样的。于是就没必要用两层循环,先枚举 \(j\) 再枚举 \(j'\) 了。可以先求出系数和,再加到所有 \(j'\) 上。时间复杂度 \(\mathcal{O}(NK)\)。
最后,把除了 \(c_{1\dots N}\) 外,其他颜色视为一种特殊颜色。可以优化为 \(\mathcal{O}(N^2)\)。
总时间复杂度 \(\mathcal{O}(N\cdot \mathrm{Bell}(4)\cdot 7^3\cdot \log L + N^2)\)。
参考代码
// problem: NFLSOJ829
#include <bits/stdc++.h>
using namespace std;
#define mk make_pair
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
template<typename T> inline void ckmax(T& x, T y) { x = (y > x ? y : x); }
template<typename T> inline void ckmin(T& x, T y) { x = (y < x ? y : x); }
const int MAXN = 1000;
const int MOD = 1e9 + 7;
inline int mod1(int x) { return x < MOD ? x : x - MOD; }
inline int mod2(int x) { return x < 0 ? x + MOD : x; }
inline void add(int &x, int y) { x = mod1(x + y); }
inline void sub(int &x, int y) { x = mod2(x - y); }
inline int pow_mod(int x, int i) {
int y = 1;
while (i) {
if (i & 1) y = (ll)y * x % MOD;
x = (ll)x * x % MOD;
i >>= 1;
}
return y;
}
int n, K, L;
int p[MAXN + 5], c[MAXN + 5];
int cols[MAXN + 5], cnt_col;
bool special_cases() {
for (int i = 1; i < n; ++i) {
if (p[i + 1] - p[i] <= 2 && c[i] == c[i + 1]) {
cout << 0 << endl;
return true;
}
}
if (K == 0) {
cout << 0 << endl;
return true;
}
if (K == 1) {
cout << (L == 1) << endl;
return true;
}
if (K == 2) {
if (L == 1) {
cout << (n == 1 ? 1 : 2) << endl;
} else if (L == 2) {
if (n == 0) {
cout << 2 << endl;
} else if (n == 1) {
cout << 1 << endl;
} else {
cout << (c[1] != c[2]) << endl;
}
} else { // L > 2
cout << 0 << endl;
}
return true;
}
if (n == 0) {
if (L == 1) {
cout << K << endl;
return true;
}
cout << (ll)K * (K - 1) % MOD * pow_mod(K - 2, L - 2) % MOD << endl;
return true;
}
if (p[1] == L) {
if (L == 1) {
cout << 1 << endl;
return true;
}
cout << (ll)(K - 1) * pow_mod(K - 2, L - 2) % MOD << endl;
return true;
}
cerr << "* no special case" << endl;
return false;
}
int dp[MAXN + 5][MAXN + 5]; // 第 i 个固定点, a[p[i] + 1] = j
/*
int g[3][3][55][3][3];
void brute_force_dp(int g[55][3][3], int x, int y) {
g[2][x][y] = 1;
for (int i = 3; i <= 50; ++i) {
add(g[i][1][2], (ll)g[i - 1][0][1] * (K - 2) % MOD);
add(g[i][0][2], (ll)g[i - 1][1][0] * (K - 2) % MOD);
add(g[i][2][1], g[i - 1][0][2]);
if (K > 3)
add(g[i][2][2], (ll)g[i - 1][0][2] * (K - 3) % MOD);
add(g[i][0][1], g[i - 1][2][0]);
if (K > 3)
add(g[i][0][2], (ll)g[i - 1][2][0] * (K - 3) % MOD);
add(g[i][2][0], g[i - 1][1][2]);
if (K > 3)
add(g[i][2][2], (ll)g[i - 1][1][2] * (K - 3) % MOD);
add(g[i][1][0], g[i - 1][2][1]);
if (K > 3)
add(g[i][1][2], (ll)g[i - 1][2][1] * (K - 3) % MOD);
add(g[i][2][1], g[i - 1][2][2]);
add(g[i][2][0], g[i - 1][2][2]);
if (K > 4)
add(g[i][2][2], (ll)g[i - 1][2][2] * (K - 4) % MOD);
}
}
*/ // 把以上的暴力 DP 换成矩阵快速幂!
/*
g[x][y] 表示最末尾的两种颜色是 x, y
0, 1 是 p[i + 1], p[i + 1] + 1 的颜色, 2 代表其他颜色
状态共有 7 种:
g[0][1] -> mat[0]
g[1][0] -> mat[1]
g[0][2] -> mat[2]
g[2][0] -> mat[3]
g[1][2] -> mat[4]
g[2][1] -> mat[5]
g[2][2] -> mat[6]
*/
struct Matrix {
int a[7][7];
void identity() {
for (int i = 0; i < 7; ++i)
for (int j = 0; j < 7; ++j)
a[i][j] = (i == j);
}
void clear() {
for (int i = 0; i < 7; ++i)
for (int j = 0; j < 7; ++j)
a[i][j] = 0;
}
Matrix() {
clear();
}
};
Matrix operator * (const Matrix& X, const Matrix& Y) {
Matrix Z;
for (int i = 0; i < 7; ++i) {
for (int j = 0; j < 7; ++j) {
for (int k = 0; k < 7; ++k) {
Z.a[i][j] = ((ll)Z.a[i][j] + (ll)X.a[i][k] * Y.a[k][j]) % MOD;
}
}
}
return Z;
}
Matrix mat_pow(Matrix X, int i) {
Matrix Y;
Y.identity();
while (i) {
if (i & 1) Y = Y * X;
X = X * X;
i >>= 1;
}
return Y;
}
Matrix Trans;
int mat_idx[3][3];
int f[MAXN + 5][3][3][3][3];
map<int, int> id;
int cnt_id;
int id_len[MAXN + 5];
int makef(int len) {
if (id.count(len)) return id[len];
id[len] = ++cnt_id;
assert(len >= 2);
Matrix A = mat_pow(Trans, len - 2);
for (int i = 0; i <= 2; ++i) {
for (int j = 0; j <= 2; ++j) {
if ((i == 0 && j == 0) || (i == 1 && j == 1))
continue;
Matrix B;
B.a[0][mat_idx[i][j]] = 1;
B = B * A;
for (int x = 0; x <= 2; ++x) {
for (int y = 0; y <= 2; ++y) {
if ((x == 0 && y == 0) || (x == 1 && y == 1))
continue;
f[cnt_id][i][j][x][y] = B.a[0][mat_idx[x][y]];
}
}
}
}
return cnt_id;
}
void init() {
// memset(g, 0, sizeof(g));
// for (int i = 0; i <= 2; ++i) for (int j = 0; j <= 2; ++j) brute_force_dp(g[i][j], i, j);
// 构造转移矩阵:
Trans.clear();
Trans.a[0][4] = K - 2; // g[0][1] -> g[1][2]
Trans.a[1][2] = K - 2; // g[1][0] -> g[0][2]
Trans.a[2][5] = 1; // g[0][2] -> g[2][1]
if (K > 3) Trans.a[2][6] = K - 3; // g[0][2] -> g[2][2]
Trans.a[3][0] = 1; // g[2][0] -> g[0][1]
if (K > 3) Trans.a[3][2] = K - 3; // g[2][0] -> g[0][2]
Trans.a[4][3] = 1; // g[1][2] -> g[2][0]
if (K > 3) Trans.a[4][6] = K - 3; // g[1][2] -> g[2][2]
Trans.a[5][1] = 1; // g[2][1] -> g[1][0]
if (K > 3) Trans.a[5][4] = K - 3; // g[2][1] -> g[1][2]
Trans.a[6][5] = 1; // g[2][2] -> g[2][1]
Trans.a[6][3] = 1; // g[2][2] -> g[2][0]
if (K > 4) Trans.a[6][6] = K - 4; // g[2][2] -> g[2][2]
mat_idx[0][1] = 0;
mat_idx[1][0] = 1;
mat_idx[0][2] = 2;
mat_idx[2][0] = 3;
mat_idx[1][2] = 4;
mat_idx[2][1] = 5;
mat_idx[2][2] = 6;
cnt_id = 0;
id.clear();
for (int i = 1; i < n; ++i) {
id_len[i] = makef(p[i + 1] - p[i] + 2);
}
if (p[n] == L) {
makef(p[n] - p[n - 1] + 1);
}
}
int calc(int id, int s1, int s2, int t1, int t2) {
int c1 = (s1 == t1 ? 0 : (s1 == t2 ? 1 : 2));
int c2 = (s2 == t1 ? 0 : (s2 == t2 ? 1 : 2));
return f[id][c1][c2][0][1];
}
void solve_case() {
cin >> n >> K >> L;
for (int i = 1; i <= n; ++i) {
cin >> p[i] >> c[i];
++p[i];
}
if (special_cases())
return;
cnt_col = 0;
for (int i = 1; i <= n; ++i) {
cols[++cnt_col] = c[i];
}
sort(cols + 1, cols + cnt_col + 1);
cnt_col = unique(cols + 1, cols + cnt_col + 1) - (cols + 1);
for (int i = 1; i <= n; ++i) {
c[i] = lower_bound(cols + 1, cols + cnt_col + 1, c[i]) - cols;
}
assert(cnt_col <= K);
init();
for (int i = 1; i <= n; ++i)
for (int j = 1; j <= cnt_col + 1; ++j)
dp[i][j] = 0;
int w = pow_mod(K - 2, p[1] - 1);
for (int j = 1; j <= cnt_col; ++j) {
if (j == c[1])
continue;
dp[1][j] = w;
}
dp[1][cnt_col + 1] = (ll)w * (K - cnt_col) % MOD;
for (int i = 1; i < n; ++i) {
// dp[i][j1] -> dp[i + 1][j2]
if (p[i] + 1 == p[i + 1]) {
if (p[i + 1] == L) {
cerr << "* last two places fixed" << endl;
cout << dp[i][c[i + 1]] << endl;
return;
}
for (int j2 = 1; j2 <= cnt_col; ++j2) {
// a[p[i + 1] + 1] = j2
if (j2 == c[i + 1] || j2 == c[i])
continue;
dp[i + 1][j2] = dp[i][c[i + 1]];
}
dp[i + 1][cnt_col + 1] = (ll)dp[i][c[i + 1]] * (K - cnt_col) % MOD;
continue;
}
if (p[i + 1] == L) {
cerr << "* last place fixed" << endl;
int ans = 0;
for (int j1 = 1; j1 <= cnt_col + 1; ++j1) {
if (j1 == c[i])
continue;
if (p[i + 1] - p[i] <= 3 && j1 == c[i + 1])
continue;
int len = p[i + 1] - p[i] + 1;
int l = id[len];
if (c[i] == c[i + 1]) {
add(ans, (ll)dp[i][j1] * (f[l][0][1][1][0] + f[l][0][1][2][0]) % MOD);
} else {
if (j1 == c[i + 1]) {
add(ans, (ll)dp[i][j1] * (f[l][0][1][0][1] + f[l][0][1][2][1]) % MOD);
} else {
add(ans, (ll)dp[i][j1] * (f[l][0][2][0][1] + f[l][0][2][2][1]) % MOD);
}
}
}
cout << ans << endl;
return;
}
int val = 0;
for (int j1 = 1; j1 <= cnt_col + 1; ++j1) {
if (j1 == c[i])
continue;
if (p[i + 1] - p[i] <= 3 && j1 == c[i + 1])
continue;
// j2 的特殊点: j1, c[i], cnt_col + 1
int j2 = 1;
while (j2 <= cnt_col && (j2 == j1 || j2 == c[i] || j2 == c[i + 1])) j2++; // 找到一般的 j2
if (j2 <= cnt_col) {
int _val = (ll)dp[i][j1] * calc(id_len[i], c[i], j1, c[i + 1], j2) % MOD;
add(val, _val);
if (j1 <= cnt_col && j1 != c[i + 1])
sub(dp[i + 1][j1], _val);
/*
for (int j2 = 1; j2 <= cnt_col; ++j2) {
if (j2 == j1 || j2 == c[i] || j2 == c[i + 1])
continue;
add(dp[i + 1][j2], _val);
}
*/
}
if (j1 <= cnt_col && p[i] + 2 < p[i + 1] && j1 != c[i + 1]) {
// j2 == j1
// j2 必 <= cnt_col
add(dp[i + 1][j1], (ll)dp[i][j1] * calc(id_len[i], c[i], j1, c[i + 1], j1) % MOD);
}
if (c[i] != c[i + 1]) {
// j2 == c[i]
// j2 必 != j1
// j2 必 <= cnt_col
add(dp[i + 1][c[i]], (ll)dp[i][j1] * calc(id_len[i], c[i], j1, c[i + 1], c[i]) % MOD);
}
if (cnt_col < K) {
// j2 == cnt_col + 1
if (j1 == cnt_col + 1) {
// 1. j2 实际等于 j1
add(dp[i + 1][cnt_col + 1], (ll)dp[i][j1] * calc(id_len[i], c[i], j1, c[i + 1], j1) % MOD);
// 2. j2 实际不等于 j1
if (K - cnt_col >= 2)
add(dp[i + 1][cnt_col + 1], (ll)dp[i][j1] * calc(id_len[i], c[i], j1, c[i + 1], j1 + 1) % MOD * (K - cnt_col - 1) % MOD);
} else {
add(dp[i + 1][cnt_col + 1], (ll)dp[i][j1] * calc(id_len[i], c[i], j1, c[i + 1], cnt_col + 1) % MOD * (K - cnt_col) % MOD);
}
}
/*
for (int j2 = 1; j2 <= cnt_col + 1; ++j2) {
if (j2 == c[i + 1])
continue;
if (p[i] + 2 == p[i + 1] && j1 <= cnt_col && j2 <= cnt_col && j1 == j2)
continue;
if (j1 == cnt_col + 1 && j2 == cnt_col + 1) {
// 1. j2 实际等于 j1
if (K - cnt_col >= 1)
add(dp[i + 1][j2], (ll)dp[i][j1] * calc(p[i + 1] - p[i] + 2, c[i], j1, c[i + 1], j2) % MOD);
// 2. j2 实际不等于 j1
if (K - cnt_col >= 2)
add(dp[i + 1][j2], (ll)dp[i][j1] * calc(p[i + 1] - p[i] + 2, c[i], j1, c[i + 1], j2 + 1) % MOD * (K - cnt_col - 1) % MOD);
continue;
}
int w = (j2 == cnt_col + 1 ? K - cnt_col : 1);
add(dp[i + 1][j2], (ll)dp[i][j1] * calc(p[i + 1] - p[i] + 2, c[i], j1, c[i + 1], j2) % MOD * w % MOD);
}
*/
}
for (int j2 = 1; j2 <= cnt_col; ++j2) {
if (j2 == c[i] || j2 == c[i + 1])
continue;
add(dp[i + 1][j2], val);
}
}
int ans = 0;
int tmp = pow_mod(K - 2, L - p[n] - 1);
for (int j = 1; j <= cnt_col + 1; ++j) {
add(ans, (ll)dp[n][j] * tmp % MOD);
}
cout << ans << endl;
}
int main() {
int T; cin >> T;
for (int t = 1; t <= T; ++t) {
cout << "Case #" << t << ": ";
cerr << endl;
solve_case();
}
return 0;
}