[BZOJ5016][Snoi2017]一个简单的询问

[BZOJ5016][Snoi2017]一个简单的询问

试题描述

给你一个长度为 \(N\) 的序列 \(a_i\)\(1 \leq i \leq N\)\(q\) 组询问,每组询问读入 \(l_1,r_1,l_2,r_2\),需输出

\begin{equation}
\sum_{x=0}^{\infty}get(l_1, r_1, x) \cdot get(l_2, r_2, x)
\notag
\end{equation}

\(get(l,r,x)\) 表示计算区间 \([l,r]\) 中,数字x出现了多少次。

输入

第一行,一个数字 \(N\),表示序列长度。

第二行,\(N\) 个数字,表示 \(a_1~a_N\)

第三行,一个数字 \(Q\),表示询问个数。

第4~\(Q+3\) 行,每行四个数字\(l_1,r_1,l_2,r_2\),表示询问。

\(N,Q \leq 50000\)

\(1 \leq i \leq N\)

\(1 \leq l_1 \leq r_1 \leq N\)

\(1 \leq l_2 \leq r_2 \leq N\)

注意:答案有可能超过int的最大值

输出

对于每组询问,输出一行一个数字,表示答案

输入示例

5
1 1 1 1 1
2
1 2 3 4
1 1 4 4

输出示例

4
1

数据规模及约定

见“输入

题解

\(Q([l_1, r_1], [l_2, r_2])\) 表示一个参数为 \(l_1, r_1, l_2, r_2\) 的询问的答案,那么

\(Q([l_1, r_1], [l_2, r_2])\)

\(\sum_{x=0}^{\infty} get(l_1, r_1, x) \cdot get(l_2, r_2, x)\)

\(= \sum_{x=0}^{\infty} (get(1, r_1, x) - get(1, l_1 - 1, x)) \cdot (get(1, r_2, x) - get(1, l_2 - 1, x))\)

\(= \sum_{x=0}^{\infty} get(1, r_1, x) \cdot get(1, r_2, x) - \sum_{x=0}^{\infty} get(1, r_1, x) \cdot get(1, l_2 - 1, x) - \sum_{x=0}^{\infty} get(1, l_1 - 1, x) \cdot get(1, r_2, x) + \sum_{x=0}^{\infty} get(1, l_1 - 1, x) \cdot get(1, l_2 - 1, x)\)

\(= Q([1, r_1], [1, r_2]) - Q([1, r_1], [1, l_2 - 1]) - Q([1, l_1 - 1], [1, r_2]) + Q([1, l_1 - 1], [1, l_2 - 1])\)

于是可以将一个四维的询问拆成 4 个二维的询问,上二维莫队即可。

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
#include <cmath>
using namespace std;

const int BufferSize = 1 << 16;
char buffer[BufferSize], *Head, *Tail;
inline char Getchar() {
	if(Head == Tail) {
		int l = fread(buffer, 1, BufferSize, stdin);
		Tail = (Head = buffer) + l;
	}
	return *Head++;
}
int read() {
	int x = 0, f = 1; char c = Getchar();
	while(!isdigit(c)){ if(c == '-') f = -1; c = Getchar(); }
	while(isdigit(c)){ x = x * 10 + c - '0'; c = Getchar(); }
	return x * f;
}

#define maxn 200010
#define LL long long

int n, q, A[maxn], bl[maxn];

struct Que { // a query in this format: query([1, l], [1, r])
	int l, r, id, tp;
	Que() {}
	Que(int _l, int _r, int _id, int _tp): l(_l), r(_r), id(_id), tp(_tp) {}
	bool operator < (const Que& t) const { return bl[l] != bl[t.l] ? bl[l] < bl[t.l] : r < t.r; }
} qs[maxn];

LL totl[maxn], totr[maxn], Ans[maxn];

int main() {
	n = read(); int m = sqrt(n + .5);
	for(int i = 1; i <= n; i++) A[i] = read(), bl[i] = (i - 1) / m + 1;
	int Q = read();
	for(int i = 1; i <= Q; i++) {
		int l1 = read(), r1 = read(), l2 = read(), r2 = read();
		q++; qs[q] = Que(r1, r2, i, 1);
		q++; qs[q] = Que(r1, l2 - 1, i, -1);
		q++; qs[q] = Que(l1 - 1, r2, i, -1);
		q++; qs[q] = Que(l1 - 1, l2 - 1, i, 1);
	}
	
	sort(qs + 1, qs + q + 1);
	int l = 1, r = 1;
	LL ans = 1; totl[A[1]] = totr[A[1]] = 1;
	for(int i = 1; i <= q; i++) {
		while(l < qs[i].l) l++, ans += totr[A[l]], totl[A[l]]++;
		while(l > qs[i].l) ans -= totr[A[l]], totl[A[l]]--, l--;
		while(r < qs[i].r) r++, ans += totl[A[r]], totr[A[r]]++;
		while(r > qs[i].r) ans -= totl[A[r]], totr[A[r]]--, r--;
		Ans[qs[i].id] += ans * qs[i].tp;
	}
	
	for(int i = 1; i <= Q; i++) printf("%d\n", Ans[i]);
	
	return 0;
}
posted @ 2017-10-03 18:36  xjr01  阅读(298)  评论(0编辑  收藏  举报