洛谷 SP5542 / SPOJ CPAIR Counting pairs
题意
给定 \(N\) 个非负整数 \(A_1,A_2,...,A_N\) 和 \(Q\) 组询问 \((v_j,a_j,b_j)\),对于第 \(j\) 组询问,你需要回答满足 \(1 \le l \le r \le N\) 且 \(a_j \le r - l + 1 \le b_j\) 且 \(\sum\limits_{k=l}^r [A_k \ge v_j] = r - l + 1\) 的整数对 \((l,r)\) 的数量。
思路
考虑将询问按 \(v_j\) 从大到小离线处理,同时将元素按 \(A_k\) 从大到小依次加入,维护由当前所有已加入元素构成的极长连续区间的 `set`。每加入一个元素,就在 `set` 中找到它的前驱区间和后继区间:若前驱的右端点或后继的左端点与该元素下标相差 \(1\),就将它们合并。
如何计算每个极长区间的贡献?考虑一个长度为 \(x\) 的极长区间,它包含 \(x\) 个长度为 \(1\) 的子区间,\(x - 1\) 个长度为 \(2\) 的子区间,……,\(1\) 个长度为 \(x\) 的子区间。设当前长度为 \(i\) 的合法子区间个数为 \(ans_i\),则加入一个长度为 \(x\) 的极长区间就相当于对所有 \(i \in [1,x]\) 执行 \(ans_i \gets ans_i + x - i + 1\)。将它拆成 \(ans_i \gets ans_i + (x + 1)\)(区间加常数,用树状数组维护)和 \(ans_i \gets ans_i - i\)(区间加一次函数,用线段树维护)两部分,维护区间和即可。删除一个极长区间(合并时拆掉旧区间)则把上述贡献取相反数。询问 \((v_j,a_j,b_j)\) 的答案即为 \(\sum_{i=a_j}^{b_j} ans_i\)。
时间复杂度 \(O((N + Q) \log N + Q \log Q)\)(含对元素和询问的排序)。
代码
code
/*
p_b_p_b txdy
AThousandMoon txdy
AThousandSuns txdy
hxy txdy
*/
#include <bits/stdc++.h>
#define pb push_back
#define fst first
#define scd second
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
typedef set<pii>::iterator sit;
// Buffered replacement for getchar(): [p1, p2) is the unread part of `buf`.
// When it is empty, the buffer is refilled with up to 1 << 21 bytes (2 MiB)
// via fread(); if fread() returns nothing the macro yields EOF. Otherwise it
// yields *p1++, i.e. a plain char, so bytes >= 0x80 may come out negative on
// platforms where char is signed.
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf[1 << 21], *p1 = buf, *p2 = buf;
// Reads the next non-negative decimal integer from the input.
// Any non-digit characters before it (spaces, newlines, a '-' sign) are
// skipped, so negative numbers are read as their absolute value.
// Returns 0 when the input ends before a digit is found. The original
// version looped forever in that case.
inline int read() {
    // Keep the character in an int so EOF stays distinguishable. Digits are
    // compared directly rather than via isdigit(), which has undefined
    // behaviour for negative char values.
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) {
        c = getchar();
    }
    int x = 0;
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x;
}
// Capacity of every 1-based array below (positions 1..n, queries 1..m).
constexpr int maxn = 100100;
int n, m;  // number of elements, number of queries
// ans[id]: answer to the id-th query, in input order.
ll ans[maxn];
// Fenwick tree for range add / range sum over lengths:
// b[] accumulates the difference array D, c[] accumulates k * D[k].
ll b[maxn], c[maxn];
// Segment tree over lengths: tree[] holds range sums, and tag[] holds a pending
// coefficient t meaning "add t * k to every position k below this node".
ll tree[maxn << 2], tag[maxn << 2];
// Maximal runs of consecutive active positions, stored as (left, right).
set<pii> st;
// One input element: its value v and its original 1-based position i.
struct node {
    int v;
    int i;
};
node a[maxn];
// One offline query: value threshold v, length bounds [a, b], input index id.
struct query {
    int v;
    int a;
    int b;
    int id;
};
query qq[maxn];
// Element order: larger value first; among equal values, smaller position first.
bool cmp(node a, node b) {
    if (a.v != b.v) {
        return a.v > b.v;
    }
    return a.i < b.i;
}
// Query order: larger threshold v first.
bool cmp2(query a, query b) {
    return b.v < a.v;
}
inline void update(int x, ll d) {
for (int i = x; i <= n; i += (i & (-i))) {
b[i] += d;
c[i] += d * x;
}
}
// Fenwick prefix sum over positions 1..x of the array whose difference array
// is D. It uses sum_{k<=x} V[k] = (x + 1) * sum D[k] - sum k * D[k].
inline ll query(ll x) {
    ll sum_d = 0;
    ll sum_kd = 0;
    for (int i = x; i != 0; i -= i & -i) {
        sum_d += b[i];
        sum_kd += c[i];
    }
    return (x + 1) * sum_d - sum_kd;
}
// Fenwick range add: adds x to every position in [l, r] through the difference
// array (D[l] += x, D[r + 1] -= x; an index above n is a no-op).
inline void update(int l, int r, ll x) {
    update(r + 1, -x);
    update(l, x);
}
// Fenwick range sum over positions [l, r] (expects l >= 1).
inline ll query(int l, int r) {
    const ll upto_r = query(r);
    const ll before_l = query(l - 1);
    return upto_r - before_l;
}
// Triangular number 1 + 2 + ... + x.
inline ll calc1(ll x) {
    const ll prod = x * (x + 1);  // product of consecutive integers: always even
    return prod / 2;
}
// Sum of the consecutive integers l + (l + 1) + ... + r (0 when r == l - 1).
// The product below is always even, so the division is exact.
inline ll calc2(ll l, ll r) {
    return (l + r) * (r - l + 1) / 2;
}
inline void pushup(int x) {
tree[x] = tree[x << 1] + tree[x << 1 | 1];
}
inline void pushdown(int x, int l, int r) {
if (!tag[x]) {
return;
}
int mid = (l + r) >> 1;
tree[x << 1] += tag[x] * calc2(l, mid);
tree[x << 1 | 1] += tag[x] * calc2(mid + 1, r);
tag[x << 1] += tag[x];
tag[x << 1 | 1] += tag[x];
tag[x] = 0;
}
// Segment tree range update: node rt covers [l, r]; adds x * k to every
// position k in [ql, qr] and maintains the range sums.
// NOTE(review): callers must keep 1 <= ql <= qr <= n. Otherwise the recursion
// can reach an interval it never covers and keep descending into it.
void update(int rt, int l, int r, int ql, int qr, ll x) {
    const bool covered = ql <= l && r <= qr;
    if (covered) {
        // Whole node is inside: its sum grows by x * (l + ... + r); defer the
        // children through the lazy coefficient.
        tree[rt] += x * calc2(l, r);
        tag[rt] += x;
        return;
    }
    pushdown(rt, l, r);
    const int mid = (l + r) >> 1;
    const int lc = rt << 1;
    if (ql <= mid) update(lc, l, mid, ql, qr, x);
    if (qr > mid) update(lc | 1, mid + 1, r, ql, qr, x);
    pushup(rt);
}
// Segment tree range sum: node rt covers [l, r]; returns the sum of positions
// in [ql, qr], pushing lazy coefficients down on the way.
// NOTE(review): requires 1 <= ql <= qr <= n, for the same reason as update().
ll query(int rt, int l, int r, int ql, int qr) {
    if (ql <= l && r <= qr) return tree[rt];
    pushdown(rt, l, r);
    const int mid = (l + r) >> 1;
    const ll left_part = ql <= mid ? query(rt << 1, l, mid, ql, qr) : 0;
    const ll right_part = qr > mid ? query(rt << 1 | 1, mid + 1, r, ql, qr) : 0;
    return left_part + right_part;
}
// Registers a maximal run of length x: for every len in [1, x] it adds
// (x + 1) - len to ans_len. The constant x + 1 goes into the Fenwick tree and
// the -len term into the segment tree.
inline void add(int x) {
    update(1, 1, n, 1, x, -1);
    update(1, x, x + 1);
}
// Unregisters a maximal run of length x, the exact inverse of add():
// for every len in [1, x] it subtracts (x + 1) - len from ans_len.
inline void del(int x) {
    update(1, 1, n, 1, x, 1);
    update(1, x, -(x + 1));
}
void solve() {
n = read();
m = read();
for (int i = 1; i <= n; ++i) {
a[i].v = read();
a[i].i = i;
}
sort(a + 1, a + n + 1, cmp);
for (int i = 1; i <= m; ++i) {
qq[i].v = read();
qq[i].a = read();
qq[i].b = read();
qq[i].id = i;
}
sort(qq + 1, qq + m + 1, cmp2);
for (int i = 1, j = 0; i <= m; ++i) {
while (j < n && a[j + 1].v >= qq[i].v) {
++j;
// printf("j: %d\n", j);
// printf("idx: %d\n", a[j].i);
st.insert(make_pair(a[j].i, a[j].i));
add(1);
sit it = st.find(make_pair(a[j].i, a[j].i));
sit tit = it;
if (it != st.begin()) {
if ((--tit)->scd == a[j].i - 1) {
// printf("a: %d %d %d %d ", tit->fst, tit->scd, it->fst, it->scd);
int tmp = tit->fst;
del(tit->scd - tit->fst + 1);
st.erase(tit);
del(it->scd - it->fst + 1);
st.erase(it);
add(a[j].i - tmp + 1);
st.insert(make_pair(tmp, a[j].i));
// printf("%d %d\n", tmp, a[j].i);
tit = it = st.find(make_pair(tmp, a[j].i));
} else {
++tit;
}
}
if ((++tit) != st.end()) {
// printf("tit: %d %d\n", tit->fst, tit->scd);
if (tit->fst == a[j].i + 1) {
// printf("b: %d %d %d %d ", it->fst, it->scd, tit->fst, tit->scd);
int tmp1 = it->fst, tmp2 = tit->scd;
del(tit->scd - tit->fst + 1);
st.erase(tit);
del(it->scd - it->fst + 1);
st.erase(it);
add(tmp2 - tmp1 + 1);
// printf("%d %d\n", tmp1, tmp2);
st.insert(make_pair(tmp1, tmp2));
}
}
}
// printf("%lld %lld\n", query(qq[i].a, qq[i].b), query(1, 1, n, qq[i].a, qq[i].b));
ans[qq[i].id] = query(qq[i].a, qq[i].b) + query(1, 1, n, qq[i].a, qq[i].b);
/*
for (pii p : st) {
printf("%d %d\n", p.fst, p.scd);
}
putchar('\n');
*/
}
for (int i = 1; i <= m; ++i) {
printf("%lld\n", ans[i]);
}
}
// Entry point: runs a single test case.
// NOTE(review): the globals (st, the Fenwick and segment-tree arrays) are never
// reset, so reading a test count here would also require clearing them.
int main() {
    int T = 1;
    // scanf("%d", &T);
    for (; T > 0; --T) {
        solve();
    }
    return 0;
}