[CF1083D]The Fair Nut’s getting crazy[单调栈+线段树]
题意
给定一个长度为 \(n\) 的序列 \(\{a_i\}\)。你需要从该序列中选出两个非空的子段,这两个子段满足
- 两个子段非包含关系。
- 两个子段存在交。
- 位于两个子段交中的元素在每个子段中只能出现一次。
求共有多少种不同的子段选择方案。输出总方案数对 \(10^9 + 7\) 取模后的结果。
需要注意的是,选择子段 \([a, b]\)、\([c, d]\) 与选择子段 \([c, d]\)、\([a, b]\) 被视为是相同的两种方案。
\(1 \leq n \leq 10^5, -10^9 \leq a_i \leq 10^9\)。
分析
-
考虑枚举一个区间 \([b,c]\) 作为交,记录 \(L_i,R_i\) 表示距离 \(i\) 最近的和 \(i\) 颜色相同的位置。
-
有: \(a\in[\max\limits_{i=b}^c{L_i},b),d\in(c,\min\limits_{i=b}^c{R_i}]\)。
-
记录可以取到的左端点的最小值(满足区间中不存在两个相同的数) \(pos\) 。 \(mi, mx\) 分别表示 \([j,i]\) 中 \(R\) 的极小值和 \(L\) 的极大值。
-
考虑从左到右枚举交区间的右端点 \(i\) ,用单调栈维护每个位置的 \(mi, mx\) 。容易得到以 \(i\) 为交区间的右端点的方案数为 \(\sum_{j=pos}^i(mi_j-i)(j-mx_j)\),拆开然后用线段树分别维护。
-
总时间复杂度为 \(O(nlogn)\)。
代码
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
#define go(u) for(int i = head[u], v = e[i].to; i; i=e[i].lst, v=e[i].to)
#define rep(i, a, b) for(int i = a; i <= b; ++i)
#define pb push_back
#define re(x) memset(x, 0, sizeof x)
inline int gi() {
int x = 0,f = 1;
char ch = getchar();
while(!isdigit(ch)) { if(ch == '-') f = -1; ch = getchar();}
while(isdigit(ch)) { x = (x << 3) + (x << 1) + ch - 48; ch = getchar();}
return x * f;
}
template <typename T> inline void Max(T &a, T b){if(a < b) a = b;}
template <typename T> inline void Min(T &a, T b){if(a > b) a = b;}
const int N = 1e5 + 7, mod = 1e9 + 7;
int n, vc;
LL ans;
int lst[N], L[N], R[N], V[N], a[N];
int st1[N], st2[N], tp1, tp2;
#define Ls o << 1
#define Rs (o << 1 | 1)
LL s1(int n) {
return 1ll * n * (n + 1) / 2;
}
LL ami[N << 2], amx[N << 2];
struct data {
LL mi, mx, smi, tm;
data operator +(const data &rhs) const {
return (data){ (mi + rhs.mi) % mod, (mx + rhs.mx) % mod, (smi + rhs.smi) % mod, (tm + rhs.tm) % mod};
}
}t[N << 2];
void add(LL &a, LL b) {
a += b;if(a >= mod) a -= mod;
}
void stmi(int l, int r, int o, int v) {
add(ami[o], v);
add(t[o].tm, 1ll * v * t[o].mx % mod);
add(t[o].mi, 1ll * (r - l + 1) * v % mod);
add(t[o].smi, (s1(r) - s1(l - 1)) % mod * v % mod);
}
void stmx(int l, int r, int o, int v) {
add(amx[o], v);
add(t[o].tm, 1ll * v * t[o].mi % mod);
add(t[o].mx, 1ll * (r - l + 1) * v % mod);
}
void pushdown(int l, int r, int o) {
int mid = l + r >> 1;
if(ami[o]) {
stmi(l, mid, Ls, ami[o]);
stmi(mid + 1, r, Rs, ami[o]);
}
if(amx[o]) {
stmx(l, mid, Ls, amx[o]);
stmx(mid + 1, r, Rs, amx[o]);
}
ami[o] = amx[o] = 0;
}
void pushup(int o) {
t[o] = t[Ls] + t[Rs];
}
void modify(int L, int R, int l, int r, int o, int v, int opt) {
if(L <= l && r <= R) {
if(!opt) stmi(l, r, o, v);
else stmx(l, r, o, v);
return;
}
pushdown(l, r, o);int mid = l + r >> 1;
if(L <= mid) modify(L, R, l, mid, Ls, v, opt);
if(R > mid) modify(L, R, mid + 1, r, Rs, v, opt);
pushup(o);
}
data query(int L, int R, int l, int r, int o) {
if(L <= l && r <= R) return t[o];
pushdown(l, r, o);int mid = l + r >> 1;
if(R <= mid) return query(L, R, l, mid, Ls);
if(L > mid) return query(L, R, mid + 1, r, Rs);
return query(L, R, l, mid, Ls) + query(L, R, mid + 1, r, Rs);
}
int main() {
n = gi();
rep(i, 1, n) a[i] = gi(), V[i] = a[i];
sort(V + 1, V + 1 + n);
vc = unique(V + 1, V + 1 + n) - V - 1;
rep(i, 1, n) a[i] = lower_bound(V + 1, V + 1 + vc, a[i]) - V;
rep(i, 1, n) {
L[i] = lst[a[i]] + 1;
lst[a[i]] = i;
}
rep(i, 1, vc) lst[i] = n + 1;
for(int i = n; i; --i) {
R[i] = lst[a[i]] - 1;
lst[a[i]] = i;
}
for(int i = 1, gg = 1; i <= n; ++i) {
for(; tp1 && L[i] >= L[st1[tp1]]; --tp1) {
modify(st1[tp1 - 1] + 1, st1[tp1], 1, n, 1, mod - L[st1[tp1]], 1);
}
modify(st1[tp1] + 1, i, 1, n, 1, L[i], 1);
st1[++tp1] = i;
for(; tp2 && R[i] <= R[st2[tp2]]; --tp2) {
modify(st2[tp2 - 1] + 1, st2[tp2], 1, n, 1, mod - R[st2[tp2]], 0);
}
modify(st2[tp2] + 1, i, 1, n, 1, R[i], 0);
st2[++tp2] = i;
Max(gg, L[i]);
data res = query(gg, i, 1, n, 1);
LL tmp = ((res.smi + i * res.mx % mod - res.tm - (s1(i) - s1(gg - 1)) % mod * i % mod) % mod + mod) % mod;
add(ans, tmp);
}
printf("%lld\n", ans);
return 0;
}