第1届ICPC青少年程序设计竞赛 E.Game
题目描述
给定两个只由 \(0,1\) 组成的字符串 \(s,t\)。可以花费 \(v_i\) 的代价改 \(t\) 中的第 \(i\) 个字符。设改好的字符串为 \(t'\),则收入为 \(f(s,t')\)。\(f(s,t')\) 表示 \(t'\) 中是 \(s\) 的子串的最长前缀的长度。一种方案的价值是收入减去代价,求最大的价值。
\(1\le |s|,|t|\le 2\cdot 10^5,0\le v_i\le 50\)。
可以发现 \(v_i\) 的值域很小。考虑按 \(m\) 的大小分类。
若 \(m\le 50\),此时直接暴力取 \(\max\),时间复杂度 \(O(nm)\)。
若 \(m>50\),此时修改带来的收益一定比代价更大,那么就是能匹配多长就匹配多长。
设从 \(s\) 中的第 \(k\) 个位置开始匹配,则需要计算
\[\begin{aligned}
\sum_{i=1}^{m}v_i[s_{k+i-1}\neq t_i]&=\sum_{i=1}^{m}v_i(s_{k+i-1}-t_{i})^2\\
&=\sum_{i=1}^{m}v_i(s_{k+i-1}+t_i-2\cdot t_i\cdot s_{k+i-1})\\
&=\sum_{i=1}^{m}v_i\cdot s_{k+i-1}+\sum_{i=1}^{m}v_i\cdot t_i-2\sum_{i=1}^{m}v_i\cdot t_i\cdot s_{k+i-1}\\
&=\sum_{i=1}^{m}v_i(1-2\cdot t_i)\cdot s_{k+i-1}+C
\end{aligned}
\]
常数 \(C\) 可以在 \(O(n)\) 内解决。另一部分则是差卷积,总时间复杂度 \(O(n\log n)\)。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e6 + 5, mod = 998244353, G = 3, Gi = 332748118;
const double pi = acos(-1);
int T, n, m, lim, L, v[N], r[N]; ll ans, pre[N], f[N], g[N]; char s[N], t[N];
inline ll qpow(ll a, ll b) { ll ans = 1; for (; b; b >>= 1, a = a * a % mod) if (b & 1) ans = ans * a % mod; return ans; }
inline void NTT(ll *A, int type) {
for (int i = 0; i < lim; ++ i) if (i < r[i]) swap(A[i], A[r[i]]);
for (int mid = 1; mid < lim; mid <<= 1) {
ll Wn = qpow(type == 1 ? G : Gi, (mod - 1) / (mid << 1));
for (int j = 0; j < lim; j += (mid << 1)) {
ll w = 1;
for (int k = 0; k < mid; ++ k, w = w * Wn % mod) {
ll x = A[j + k], y = w * A[j + k + mid] % mod;
A[j + k] = (x + y) % mod, A[j + k + mid] = (x - y + mod) % mod;
}
}
}
if (type == 1) return ;
ll invs = qpow(lim, mod - 2);
for (int i = 0; i < lim; ++ i) A[i] = A[i] * invs % mod;
}
inline void solve() {
cin >> n >> m >> s + 1 >> t + 1, ans = 0;
for (int i = 1; i <= m; ++ i) cin >> v[i];
if (m <= 50) {
for (int i = 1; i <= n; ++ i) {
ll cost = 0;
for (int j = 1; j <= min(m, n - i + 1); ++ j) {
if (s[i + j - 1] != t[j]) cost += v[j];
ans = max(ans, (ll)j * j - cost);
}
}
return void(cout << ans << '\n');
}
m = min(m, n);
for (int i = 1; i <= m; ++ i) pre[i] = pre[i - 1] - (t[i] - 48) * v[i];
for (int i = max(n - 50, 1); i <= n; ++ i) {
ll cost = 0;
for (int j = 1; j <= min(m, n - i + 1); ++ j) {
if (s[i + j - 1] != t[j]) cost += v[j];
ans = max(ans, (ll)j * j - cost);
}
}
lim = 1, L = 0; while (lim <= n + m) lim <<= 1, ++ L;
for (int i = 0; i < lim; ++ i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
for (int i = 0; i < n; ++ i) f[i] = s[n - i] - 48;
for (int i = 1; i <= m; ++ i) g[i] = (2 * (t[i] - 48) * v[i] - v[i] + mod) % mod;
NTT(f, 1), NTT(g, 1);
for (int i = 0; i < lim; ++ i) f[i] = f[i] * g[i] % mod;
NTT(f, -1);
for (int i = 1; i <= n; ++ i) {
int len = min(m, n - i + 1);
ll sum = pre[len] + (ll)len * len, ad = f[n - i + 1];
if (ad > mod - 1e7) ad -= mod;
ans = max(ans, sum + ad);
}
cout << ans << '\n';
for (int i = 0; i < lim; ++ i) f[i] = g[i] = 0;
}
int main() {
ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
for (cin >> T; T --; ) solve();
return 0;
}