[WC/CTS 2024] T2 水镜 题解
这里好像是 \(O(nlog^2n)\) 的垃圾做法。
不过好像有 \(O(n)\) 的,但是蒟蒻太菜了不会 qwq。
约定
首先,这个 \(2 \times L\) 似乎没有什么作用,那下文就直接用 \(L\) 替代。
其次,令集合 \(H_i\) 为 \(\{h_i,L-h_i\}\),则 \(r_i\) 属于 \(H_i\)。
\(n \le 100\)
考虑先枚举区间 \([l,r]\),假设我们已知 \(L\),显然有一个设计 \(f_{i,0/1}\) 状态的 dp 方法进行判定。
列出转移方程可以发现,转移和 \(H_i\) 的具体数值无关,我们关心的是 \(H_i\) 和 \(H_{i-1}\) 这两个集合之间两两元素的大小关系。
也就是说,对于每个 \(i\),有关转移的只有 \(4\) 种类型。
-
\(h_{i-1} < h_i\),这种情况是平凡的;
-
\(L - h_{i-1} < h_i\),等价于 \(L < h_{i-1}+h_i\);
-
\(h_{i-1} < L - h_i\),等价于 \(L > h_{i-1}+h_i\);
-
\(L - h_{i-1} < L - h_i\),这与 1 相反,也是平凡的。
综上,我们可以发现对于每个 \(i\),\(L\) 只有一个分界值,即 \(h_{i-1}+h_i\)。
所以,本质不同的 dp 转移,一共就只有 \(r-l+1\) 中不同的情况。
我们可以考虑枚举这几种转移方式,再 dp 一遍,复杂度 \(O(n^4)\)。
\(n \le 4000\)
由上面的发现不难得出,本质不同的 \(L\),一共就只有 \(n\) 种,先考虑枚举 \(L\)。
考虑这个 dp 的图伦建模,就是若 \(f_{i-1,k_1}\) 可以转移到 \(f_{i,k_2}\),则 \((i-1,k_1)\) 就向 \((i,k_2)\) 连边,
则,合法当且仅当可以从 \((r,t_2)\) 走到 \((l,t_1)\)(即联通),其中 \(0\le k_1,k_2,t_1,t_2\le1\)。
所以,对于转移本质相同的 \(L\),建出的图形也必然相同。显然,如果固定右端点 \(r\),那么我们只需求出最左侧合法的左端点 \(p\),而合法左端点 \(l\) 的个数就有 \(r-p\) 种。
这样的话我们可以用并查集维护连通性,得到每个 \(r\) 的最左侧合法的左端点 \(p_r\),再对每个 \(L\) 的 \(p_r\) 取个 min 就可以得到每个右端点最远的合法左端点了。
复杂度就可以做到 \(O(n^2)\)。
\(n \le 500000\)
对于两个相邻的本质不同的 \(L\),其对应的无向图的变化并不大,就会有一些边从 \((i-1,1)\) 到 \((i,0 )\) 的边删除,加入一些 \((i-1,0)\) 到 \((i,1)\) 的边。
一共的加边删边操作有 \(O(n)\) 次,所以可以考虑 线段树分治。
然后用 待撤销并查集 维护连通性,可以发现每个并查集所维护的点一定是一段区间。
注意:这里的一段区间指的是这些点的第一关键字是连续的,与第二关键字无关,比如 \((1,0)\),\((2,1)\) 和 \((3,0)\) 。
合并并查集的时候不妨设两个区间是 \(l_x,r_x,l_y,r_y\) 且有 \(l_x \le l_y\),
然后顺便执行再 \([l_y,r_y]\) 中区间与 \(l_x\) 取 min(毕竟是联通了的)。
撤销的时候记得撤销 \(l_u,r_u\) 的值即可,最后直接统计答案就 100pts 了。
复杂度是 \(O(nlog^2n)\),可能要卡一卡才能过去(
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int rd() {
int x = 0, f = 1;
char ch = getchar();
while (!('0' <= ch && ch <= '9')) {
if (ch == '-') f = -1; ch = getchar();
}
while ('0' <= ch && ch <= '9') {
x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();
}
return x * f;
}
void wr(int x) {
if (x < 0) putchar('-'), x = -x;
if (x >= 10) wr(x / 10); putchar(x % 10 + '0');
}
const int N = 5e5 + 10;
struct tree {
int l, r; vector<pair<int,int> > e;
} t[N << 2];
int tot, pos; ll x[N];
int n, mn[30][N]; ll h[N];
int f[N << 1], l[N << 1], r[N << 1], rnk[N << 1];
pair<int*,int> opt[N << 3]; int cnt;
inline int find (int x) {
if (f[x] == x) return x;
return find(f[x]);
}
void sol_min (int l, int r, int v) {
// cout << l << " " << r << " " << v << endl;
if (r < l) return ;
int k = log2(r - l + 1);
mn[k][l] = min(mn[k][l], v); mn[k][r - (1 << k) + 1] = min(mn[k][r - (1 << k) + 1], v);
}
void get_min () {
for (int k = 20; k >= 1; --k) {
for (int i = 1; i <= n - (1 << k) + 1; ++i) {
mn[k - 1][i] = min(mn[k - 1][i], mn[k][i]);
mn[k - 1][i + (1 << (k - 1))] = min(mn[k - 1][i + (1 << (k - 1))], mn[k][i]);
}
}
}
inline void merge (int u, int v) {
// cout << u << " " << v << endl;
u = find(u), v = find(v);
// cout << u << " " << v << " " << l[u] << " " << l[v] << " " << r[u] << " " << r[v] << endl;
if (u == v) return ;
if (l[u] > l[v]) swap(u, v); sol_min(l[v], r[v], l[u]);
if (rnk[u] > rnk[v]) {
opt[++cnt] = make_pair (l + u, l[u]);
opt[++cnt] = make_pair (r + u, r[u]);
opt[++cnt] = make_pair (f + v, v); f[v] = u;
l[u] = min(l[u], l[v]); r[u] = max(r[u], r[v]);
} else {
opt[++cnt] = make_pair (l + v, l[v]);
opt[++cnt] = make_pair (r + v, r[v]);
opt[++cnt] = make_pair (f + u, u); f[u] = v;
l[v] = min(l[v], l[u]); r[v] = max(r[v], r[u]);
if (rnk[u] == rnk[v]) {
opt[++cnt] = make_pair (rnk + v, rnk[v]); rnk[v]++;
}
}
}
void build (int x, int l, int r) {
t[x].l = l; t[x].r = r; t[x].e.clear();
if (l == r) return ;
int mid = (l + r) >> 1;
build (x << 1, l, mid); build (x << 1 | 1, mid + 1, r);
}
void update (int x, int l, int r, pair<int,int>v) {
if (l <= t[x].l && t[x].r <= r) {
// ++pos;
t[x].e.push_back(v); return ;
} int mid = (t[x].l + t[x].r) >> 1;
if (l <= mid) update (x << 1, l, r, v);
if (r > mid) update (x << 1 | 1, l, r, v);
}
void solve (int x) {
int now = cnt; //cout << x << endl;
for (auto v : t[x].e) merge(v.first, v.second);
if (t[x].l != t[x].r) { solve (x << 1); solve (x << 1 | 1); }
while (cnt > now) *opt[cnt].first = opt[cnt].second, cnt--;
// cout << endl;
}
signed main() {
// freopen ("P10144_20.in", "r", stdin);
ios::sync_with_stdio(false); cin.tie(0);
double st = clock();
cin >> n;
for (int i = 1; i <= n; ++i) cin >> h[i];
for (int i = 1; i <= n; ++i) f[i] = i, l[i] = i, r[i] = i;
for (int i = n + 1; i <= 2 * n; ++i) f[i] = i, l[i] = i - n, r[i] = i - n;
for (int i = 1; i < n; ++i) x[++tot] = h[i] + h[i + 1];
sort (x + 1, x + tot + 1); tot = unique(x + 1, x + tot + 1) - x - 1;
build (1, 1, tot + 1);
for (int i = 1; i < n; ++i) {
int p = lower_bound (x + 1, x + tot + 1, h[i] + h[i + 1]) - x;
if (h[i] < h[i + 1]) merge (i, i + 1);
else if (h[i] > h[i + 1]) merge (i + n, i + n + 1);
update (1, 1, p, make_pair(i + n, i + 1)); update (1, p + 1, tot + 1, make_pair(i, i + n + 1));
} //cout << (clock() - st) / CLOCKS_PER_SEC << endl;
for (int k = 0; k <= 20; ++k) {
for (int i = 1; i <= n; ++i) mn[k][i] = 1e9 + 1;
}
for (int i = 1; i <= n; ++i) mn[0][i] = i;
solve(1);
get_min();
ll sum = 0; for (int i = 1; i <= n; ++i) sum += i - mn[0][i];
cout << sum << endl;
// cout << (clock() - st) / CLOCKS_PER_SEC << endl;
return 0;
}