洛谷 SP5542 / SPOJ CPAIR Counting pairs

洛谷传送门

SPOJ 传送门

题意

给定 \(N\) 个非负整数 \(A_1,A_2,\ldots,A_N\) 和 \(Q\) 组询问 \((v_j,a_j,b_j)\)。对于第 \(j\) 组询问,你需要回答满足 \(1 \le l \le r \le N\)、\(a_j \le r - l + 1 \le b_j\) 且 \(\sum\limits_{k=l}^r [A_k \ge v_j] = r - l + 1\)(即区间 \([l,r]\) 内所有元素都不小于 \(v_j\))的整数对 \((l,r)\) 的数量。

思路

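考虑将询问离线:把询问按 \(v\) 从大到小排序,同时把元素按值从大到小依次加入。这样回答第 \(j\) 组询问时,已加入的恰好是所有满足 \(A_k \ge v_j\) 的位置。用一个 set 维护由当前已加入位置构成的所有极长连续区间。每加入一个位置,就在 set 中找到它的前驱区间和后继区间:若前驱的右端点恰为该位置减 \(1\),或后继的左端点恰为该位置加 \(1\),就将它们合并。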

如何计算每个极长区间的贡献?一个长度为 \(x\) 的极长区间包含 \(x\) 个长度为 \(1\) 的子区间、\(x - 1\) 个长度为 \(2\) 的子区间、……、\(1\) 个长度为 \(x\) 的子区间。设 \(ans_i\) 为当前长度为 \(i\) 的合法区间个数,则加入一个长度为 \(x\) 的极长区间,相当于对所有 \(i \in [1,x]\) 执行 \(ans_i \gets ans_i + x - i + 1\)。将它拆成 \(ans_i \gets ans_i + (x + 1)\) 和 \(ans_i \gets ans_i - i\) 两部分。前者是区间加常数,用树状数组维护区间和;后者是区间加 \(-i\),用带懒标记的线段树维护区间和。删除一个极长区间(合并时需要先删除旧区间)时,将上述修改取相反数即可。第 \(j\) 组询问的答案为 \(\sum_{i=a_j}^{b_j} ans_i\)。

时间复杂度 \(O(N \log N + Q \log N)\)

代码

code
/*

p_b_p_b txdy
AThousandMoon txdy
AThousandSuns txdy
hxy txdy

*/

#include <bits/stdc++.h>
// Short-hand macros. pb is not used anywhere in this file; fst/scd abbreviate
// the members of std::pair (used for the interval endpoints stored in st).
#define pb push_back
#define fst first
#define scd second

using namespace std;
// 64-bit accumulator type: answers can reach roughly N^2 / 2.
typedef long long ll;
// A closed interval of positions [first, second].
typedef pair<int, int> pii;
// Iterator type for set<pii>, used to walk the interval set st.
typedef set<pii>::iterator sit;

// Buffered replacement for getchar(): refills a 2 MiB buffer from stdin with
// fread and hands out one byte at a time. Bytes are yielded as unsigned char
// values (0..255), so they are always a valid argument for isdigit, and EOF is
// returned once stdin is exhausted.
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : (unsigned char)*p1++)
char buf[1 << 21], *p1 = buf, *p2 = buf;

// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit bytes before it (no sign handling).
// Returns 0 if stdin runs out before a digit is found. The previous version
// looped forever in that case, because it never checked for EOF.
inline int read() {
	int c = getchar(); // int, not char: must be able to hold EOF distinctly
	int x = 0;
	for (; c != EOF && !isdigit(c); c = getchar()) ;
	for (; c != EOF && isdigit(c); c = getchar()) x = x * 10 + (c - '0');
	return x;
}

// Capacity of the 1-indexed static arrays: N and Q must both be < maxn.
const int maxn = 100100;

int n, m;           // number of elements, number of queries
ll ans[maxn];       // ans[id]: answer to the id-th query, in input order
ll tree[maxn << 2]; // segment-tree node sums
ll tag[maxn << 2];  // segment-tree lazy coefficients
ll b[maxn];         // Fenwick tree over the difference array d
ll c[maxn];         // Fenwick tree over k * d[k]
set<pii> st;        // disjoint runs [l, r] of currently present positions

// An array element together with its 1-based position in the input.
struct node {
	int v; // the value A_i
	int i; // the position i
} a[maxn];

// One offline query: threshold v, allowed window-length range [a, b], and
// the 1-based index of the query in the input.
struct query {
	int v;  // every element of a counted window must be >= v
	int a;  // minimum window length
	int b;  // maximum window length
	int id; // input order, used to print answers in order
} qq[maxn];

// Sort key for elements: larger value first; ties broken by smaller position.
bool cmp(node a, node b) {
	if (a.v != b.v) {
		return a.v > b.v;
	}
	return a.i < b.i;
}

// Sort key for queries: larger threshold v first.
bool cmp2(query a, query b) {
	return b.v < a.v;
}

inline void update(int x, ll d) {
	for (int i = x; i <= n; i += (i & (-i))) {
		b[i] += d;
		c[i] += d * x;
	}
}

// Prefix sum of the underlying array over positions 1..x, computed as
// sum_{k<=x} (x + 1 - k) * d[k] = (x + 1) * sum d[k] - sum k * d[k].
inline ll query(ll x) {
	ll sumD = 0;
	ll sumKD = 0;
	for (int i = x; i != 0; i -= i & -i) {
		sumD += b[i];
		sumKD += c[i];
	}
	return (x + 1) * sumD - sumKD;
}

// Adds x to every position in [l, r] of the Fenwick-backed array, expressed as
// two point updates on its difference array.
inline void update(int l, int r, ll x) {
	update(r + 1, -x);
	update(l, x);
}

// Sum of the Fenwick-backed array over positions [l, r].
inline ll query(int l, int r) {
	const ll upToR = query(r);
	const ll beforeL = query(l - 1);
	return upToR - beforeL;
}

// Triangular number x * (x + 1) / 2, i.e. 1 + 2 + ... + x for x >= 0.
inline ll calc1(ll x) {
	const ll product = x * (x + 1);
	return product / 2;
}

// Sum of the consecutive integers l, l + 1, ..., r.
inline ll calc2(ll l, ll r) {
	const ll throughR = calc1(r);
	const ll beforeL = calc1(l - 1);
	return throughR - beforeL;
}

// Recomputes node x's sum from its two children.
inline void pushup(int x) {
	const int lc = x << 1;
	tree[x] = tree[lc] + tree[lc | 1];
}

inline void pushdown(int x, int l, int r) {
	if (!tag[x]) {
		return;
	}
	int mid = (l + r) >> 1;
	tree[x << 1] += tag[x] * calc2(l, mid);
	tree[x << 1 | 1] += tag[x] * calc2(mid + 1, r);
	tag[x << 1] += tag[x];
	tag[x << 1 | 1] += tag[x];
	tag[x] = 0;
}

// Segment-tree update: adds x * i to every position i in [ql, qr].
// Node rt covers [l, r].
void update(int rt, int l, int r, int ql, int qr, ll x) {
	const bool covered = ql <= l && r <= qr;
	if (covered) {
		tree[rt] += x * calc2(l, r);
		tag[rt] += x;
		return;
	}
	pushdown(rt, l, r);
	const int mid = (l + r) >> 1;
	if (ql <= mid) update(rt << 1, l, mid, ql, qr, x);
	if (mid < qr) update(rt << 1 | 1, mid + 1, r, ql, qr, x);
	pushup(rt);
}

// Segment-tree query: sum of positions [ql, qr] within node rt covering [l, r].
ll query(int rt, int l, int r, int ql, int qr) {
	if (ql <= l && r <= qr) {
		return tree[rt];
	}
	pushdown(rt, l, r);
	const int mid = (l + r) >> 1;
	const ll leftPart = (ql <= mid) ? query(rt << 1, l, mid, ql, qr) : 0LL;
	const ll rightPart = (qr > mid) ? query(rt << 1 | 1, mid + 1, r, ql, qr) : 0LL;
	return leftPart + rightPart;
}

// Accounts for a new maximal run of length x. For every length i in [1, x]
// the run holds x - i + 1 windows. That amount is split into a constant
// +(x + 1), kept in the Fenwick tree, and a position-dependent -i, kept in the
// segment tree.
inline void add(int x) {
	update(1, 1, n, 1, x, -1);
	update(1, x, x + 1);
}

// Removes the contribution of a maximal run of length x (exact inverse of add).
inline void del(int x) {
	update(1, 1, n, 1, x, 1);
	update(1, x, -(x + 1));
}

void solve() {
	n = read();
	m = read();
	for (int i = 1; i <= n; ++i) {
		a[i].v = read();
		a[i].i = i;
	}
	sort(a + 1, a + n + 1, cmp);
	for (int i = 1; i <= m; ++i) {
		qq[i].v = read();
		qq[i].a = read();
		qq[i].b = read();
		qq[i].id = i;
	}
	sort(qq + 1, qq + m + 1, cmp2);
	for (int i = 1, j = 0; i <= m; ++i) {
		while (j < n && a[j + 1].v >= qq[i].v) {
			++j;
			// printf("j: %d\n", j);
			// printf("idx: %d\n", a[j].i);
			st.insert(make_pair(a[j].i, a[j].i));
			add(1);
			sit it = st.find(make_pair(a[j].i, a[j].i));
			sit tit = it;
			if (it != st.begin()) {
				if ((--tit)->scd == a[j].i - 1) {
					// printf("a: %d %d %d %d ", tit->fst, tit->scd, it->fst, it->scd);
					int tmp = tit->fst;
					del(tit->scd - tit->fst + 1);
					st.erase(tit);
					del(it->scd - it->fst + 1);
					st.erase(it);
					add(a[j].i - tmp + 1);
					st.insert(make_pair(tmp, a[j].i));
					// printf("%d %d\n", tmp, a[j].i);
					tit = it = st.find(make_pair(tmp, a[j].i));
				} else {
					++tit;
				}
			}
			if ((++tit) != st.end()) {
				// printf("tit: %d %d\n", tit->fst, tit->scd);
				if (tit->fst == a[j].i + 1) {
					// printf("b: %d %d %d %d ", it->fst, it->scd, tit->fst, tit->scd);
					int tmp1 = it->fst, tmp2 = tit->scd;
					del(tit->scd - tit->fst + 1);
					st.erase(tit);
					del(it->scd - it->fst + 1);
					st.erase(it);
					add(tmp2 - tmp1 + 1);
					// printf("%d %d\n", tmp1, tmp2);
					st.insert(make_pair(tmp1, tmp2));
				}
			}
		}
		// printf("%lld %lld\n", query(qq[i].a, qq[i].b), query(1, 1, n, qq[i].a, qq[i].b));
		ans[qq[i].id] = query(qq[i].a, qq[i].b) + query(1, 1, n, qq[i].a, qq[i].b);
		/*
		for (pii p : st) {
			printf("%d %d\n", p.fst, p.scd);
		}
		putchar('\n');
		*/
	}
	for (int i = 1; i <= m; ++i) {
		printf("%lld\n", ans[i]);
	}
}

// Entry point: runs exactly one test case (T is fixed to 1).
int main() {
	for (int T = 1; T > 0; --T) {
		solve();
	}
	return 0;
}
posted @ 2022-06-21 12:34  zltzlt  阅读(34)  评论(0)  编辑  收藏  举报