P8349 [SDOI/SXOI2022] 整数序列
记 \(len_x\) 为 \(x\) 在 \(a\) 中出现的次数,显然有 \(\mathcal O(nq)\) 的暴力,拿下 \(20\) 分。
感觉用数据结构难以维护,考虑根号做法。
根号做法有一个阈值 \(L\),然后分讨(钦定 \(len_x \le len_y\)):
- \(len_x \le len_y \le L\)
直接暴力做,时间复杂度 \(\mathcal O(L)\)。
- \(L < len_x \le len_y\)
这样的 \(x, y\) 有 \(\mathcal O(\dfrac{n^2}{L^2})\) 种,做一次的时间复杂度是 \(\mathcal O(L)\),预处理的时间复杂度是 \(\mathcal O(\dfrac{n^2}L)\),此后 \(\mathcal O(\log L)\) 查询就好了(map)。
- \(len_x \le L < len_y\)
只有 \(\sum\limits_{i = l}^r c_i = 0\) 的区间才对答案有贡献,显然有不少无用的 \(y\)。
我们对每个 \(x\) 在左右各保留 \(2\) 个 \(y\),不难发现这样的序列和原序列等价,用 set 转化然后暴力做,单次时间复杂度是 \(\mathcal O(L \log n)\),总时间复杂度就是 \(\mathcal qL \log n\)。
为什么要两个 \(y\)?
每个 \(x\) 只能向左或向右匹配一个 \(y\),但是如果只保留一个 \(y\) 的话,会出现这样的情况:
yyxxyyyyyyyyyyxxyy
变为了yyxxyyxxyy
,原序列上最长只能选到xxyy
,但新序列上最长能选到xxyyxxyy
。这启示我们 \(y\) 还具有分割作用,所以再多保留一个 \(y\) 即可。
忽略小的时间复杂度,由均值不等式得 \(\dfrac{n^2}L + qL \log n \ge 2n\sqrt{q \log n}\),当且仅当 \(L = \dfrac{n^2}{q \log n}\) 时取等。
于是时间复杂度就是 \(\mathcal O(n \sqrt{q \log n})\)。
代码:
#include <bits/stdc++.h>
using namespace std;
typedef pair<int, int> pii;
typedef long long ll;
constexpr int N = 3e5 + 10, L = 500;
constexpr ll inf = 9e18;
int n, q, b[N];
vector<int> p[N];
set<int> s[N];
namespace SS {
int cnt, a[N], c[N];
unordered_map<int, ll> mn;
ll solve(int x, int y) {
cnt = 0; int i = 0, j = 0;
while (i < p[x].size() && j < p[y].size()) {
cnt++;
if (p[x][i] < p[y][j]) a[cnt] = p[x][i++], c[cnt] = 1;
else a[cnt] = p[y][j++], c[cnt] = -1;
}
while (i < p[x].size()) a[++cnt] = p[x][i++], c[cnt] = 1;
while (j < p[y].size()) a[++cnt] = p[y][j++], c[cnt] = -1;
int s = 0; ll sum = 0, ans = -inf;
mn.clear(), mn[0] = 0;
for (int i = 1; i <= cnt; i++) {
s += c[i], sum += b[a[i]];
if (mn.find(s) != mn.end()) {
ans = max(ans, sum - mn[s]);
mn[s] = min(mn[s], sum);
} else mn[s] = sum;
}
return ans;
}
}
namespace BB {
int cnt, a[N], c[N];
map<pii, ll> mem;
ll solve(int x, int y) {
if (mem.find(pii(x, y)) != mem.end()) return mem[pii(x, y)];
return mem[pii(x, y)] = SS::solve(x, y);
}
}
namespace SB {
int cnt;
struct Node {
int p, c;
bool operator<(const Node &rhs) const {return p < rhs.p;}
} a[N];
vector<set<int>::iterator> del;
unordered_map<int, ll> mn;
ll solve(int x, int y) {
cnt = 0;
for (int i : p[x]) {
a[++cnt] = {i, 1};
auto it = s[y].upper_bound(i);
if (it != s[y].begin()) {
it--, del.emplace_back(it), a[++cnt] = {*it, -1};
if (it != s[y].begin()) it--, del.emplace_back(it), a[++cnt] = {*it, -1}, it++;
it++;
}
if (it != s[y].end()) {
del.emplace_back(it), a[++cnt] = {*it, -1}; it++;
if (it != s[y].end()) del.emplace_back(it), a[++cnt] = {*it, -1};
}
for (auto i : del) s[y].erase(i); del.clear();
}
sort(a + 1, a + cnt + 1);
int sc = 0; ll sum = 0, ans = -inf;
mn.clear(), mn[0] = 0;
for (int i = 1; i <= cnt; i++) {
sc += a[i].c, sum += b[a[i].p];
if (mn.find(sc) != mn.end()) {
ans = max(ans, sum - mn[sc]);
mn[sc] = min(mn[sc], sum);
} else mn[sc] = sum;
if (a[i].c == -1) s[y].emplace(a[i].p);
}
return ans;
}
}
int main() {
ios_base::sync_with_stdio(0); cin.tie(nullptr), cout.tie(nullptr);
cin >> n >> q;
for (int i = 1, a; i <= n; i++) cin >> a, p[a].emplace_back(i), s[a].emplace(i);
for (int i = 1; i <= n; i++) cin >> b[i];
while (q--) {
int x, y; cin >> x >> y;
if (p[x].size() > p[y].size()) swap(x, y);
if (p[y].size() <= L) cout << SS::solve(x, y) << '\n';
else if (p[x].size() > L) cout << BB::solve(x, y) << '\n';
else cout << SB::solve(x, y) << '\n';
}
return 0;
}