CSP-S 2023 消消乐
DP,Hash,ad-hoc
算法 1
考虑区间 DP。设 \(f(i, j)\) 表示区间 \([i, j]\) 是否可消除,则
最后统计多少字符串是可消除的。时间复杂度为 \(O(n^3)\)。
算法 2
我们枚举所有字串,然后判断,时间复杂度同样为 \(O(n^3)\),但是这是没必要的。我们统计从 \(i\) 开始的可消除子串数量,相当于做一遍括号匹配。统计栈空的时候即可。时间复杂度为 \(O(n^2)\)。
std::cin >> n >> s;
for (int i = 0; i < n; i++) {
std::stack<int> st;
for (int j = i; j < n; j++) {
if (!st.empty() && s[st.top()] == s[j]) {
st.pop();
} else {
st.push(j);
}
ans += st.empty();
}
}
std::cout << ans << '\n';
算法 3
设 \(f(i)\) 表示以 \(i\) 结尾的可消除序列个数,设 \([i', i]\) 可消除,则 \(f(i) = f(i') + 1\)。则答案为 \(\sum f\)。时间复杂度为 \(O(n^2)\)。
算法 4
设 \(g(i, j)\) 表示 \(s_{g(i, j)}\) 为 \(j\),且 \([g(i, j) + 1, i]\) 是可以消除的最大位置。则 \(f(i) = f(g(i - 1, s_i) - 1) + 1\)。
接下来需要求 \(g(i, j)\),初始时 \(g(i, s_i) = i\),即没有消除。如果存在 \(g(i - 1, s_i) - 1\),那么 \(g(i, \Sigma)\) 都可以由 \(g(g(i - 1, s_i) - 1, \Sigma)\) 得到。
时间复杂度为 \(O(n \left| \Sigma \right|)\)。其中 \(\Sigma\) 是字符集。
#include <bits/stdc++.h>
typedef long long LL;
const int N = 2e6 + 5;
int n;
std::string s;
int f[N], g[N][26];
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cin >> n >> s;
s = ' ' + s;
LL ans = 0;
for (int i = 1; i <= n; i++) {
int x = s[i] - 'a';
if (g[i - 1][x] > 0) {
f[i] = f[g[i - 1][x] - 1] + 1;
for (int j = 0; j < 26; j++) {
g[i][j] = g[g[i - 1][x] - 1][j];
}
}
g[i][x] = i;
ans += f[i];
}
std::cout << ans << '\n';
return 0;
}
算法 5
用一个栈从 \(1\) 开始做类似括号匹配,显然我们需要做 \(n\) 次是因为当出现可消除序列时,栈不一定为空。如果我们记录 \(i\) 时栈的状态为 \(S_i\),如果 \(S_i\) 和 \(S_j\) 是相等的,说明 \([i, j]\) 可消除。将栈内元素 Hash 后用 std::map 记录下来,通过累加 \(S_i\) 的出现次数。时间复杂度为 \(O(n \log n)\)。如果使用 std::unordered_map,则时间复杂度最优为 \(O(n)\)。
#include <bits/stdc++.h>
typedef long long LL;
typedef unsigned long long ULL;
const int N = 2e6 + 5, base = 131;
int n;
std::string s;
std::map<ULL, int> mp; // std::unordered_map
std::stack<int> st;
ULL h[N];
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cin >> n >> s;
s = ' ' + s;
LL ans = 0; mp[0]++;
for (int i = 1; i <= n; i++) {
if (!st.empty() && s[st.top()] == s[i]) {
h[i] = h[st.top() - 1];
st.pop();
} else {
st.push(i);
h[i] = h[i - 1] * base + (s[i] - 'a' + 1);
}
ans += mp[h[i]];
mp[h[i]]++;
}
std::cout << ans << '\n';
return 0;
}

浙公网安备 33010602011771号