[树状数组] HH的项链 题解
题目描述
HH有一串由各种漂亮的贝壳组成的项链。HH相信不同的贝壳会带来好运,所以每次散步完后,他都会随意取出一段贝壳,思考它们所表达的含义。
HH不断地收集新的贝壳,因此他的项链变得越来越长。
有一天,他突然提出了一个问题:某一段贝壳中,包含了多少种不同的贝壳?这个问题很难回答。。。因为项链实在是太长了。
于是,他只好求助睿智的你,来解决这个问题。
输入格式
第一行:一个整数N,表示项链的长度。
第二行:N个整数,表示依次表示项链中贝壳的编号(编号为0到1000000之间的整数)。
第三行:一个整数M,表示HH询问的个数。
接下来M行:每行两个整数,L和R(1 ≤ L ≤ R ≤ N),表示询问的区间。
输出格式
M行,每行一个整数,依次表示询问对应的答案。
样例
样例输入
6
1 2 3 4 3 5
3
1 2
3 5
2 6
样例输出
2
2
4
题解
区间查询,单点修改。。。这一看就是树状数组啊;
我们开一个树状数组t表示每个位置出现数的种类,很容易发现难点:一个区间中,同一个数出现的频率不固定;
于是,为了避免重复,我们开一个数组ma,ma[i] 代表i这个数在前面出现的位置,每次在遍历区间时判断这个数前面是否有值,如果有,那我们就“以新换旧”,在树状数组中删去原来位置的数(即-1),并且保证每次都将新位置的数的种类+1(这两步即单点修改),最后按区间顺序进行m次区间查询(树状数组前缀和)即可;
为了实现上述思路,我们在预处理时需要将每个区间按右端点升序排列,这样就可以将上一步的ma继承下来接着用;
但还会有一个问题:如果ma被重复更新了,那该咋办?
定义la表示上一个区间右端点+1,每次更新ma时从la开始循环,一直到区间右端点,这样就可以避免重复更新;
证明此做法的正确性:
考虑两个区间,我们发现会有3种情况:
-
第一个区间完全包括于第二个,此时会更新la~第二个区间右端点的值,不会更新前面的值;
-
第二个与第一个交叉,此时只会更新它两个区间不重叠且属于第二个区间的部分;
-
第二个在第一个右边,此时从第一个的右端点+1到第二个的端点会更新且不会更新前面;
综上,此做法正确;
如果不升序排列的话,那和暴力没啥区别;
剩下的细节处理请看代码;
代码
#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;
int n, m;
int a[1000005];
int t[1000005];
int ma[1000005]; //ma[i]表示i上一次出现的位置;
int ans[1000005];
struct sss{
int x, y, id;
}e[1000005];
bool cmp(sss a, sss b) { //升序排序;
return a.y < b.y;
}
int lowbit(int x) {
return x & (-x);
}
void add_dian(int x, int k) { //单点修改;
while(x <= n) {
t[x] += k;
x += lowbit(x);
}
}
int ask_sum(int l, int r) { //区间查询;
int ans = 0;
int i = l - 1;
while(i > 0) {
ans -= t[i];
i -= lowbit(i);
}
i = r;
while(i > 0) {
ans += t[i];
i -= lowbit(i);
}
return ans;
}
int main() {
cin >> n;
memset(ma, 0, sizeof(ma));
for (int i = 1; i <= n; i++) {
cin >> a[i];
}
cin >> m;
for (int i = 1; i <= m; i++) {
cin >> e[i].x >> e[i].y;
e[i].id = i; //为了最后按区间顺序输出;
}
sort(e + 1, e + 1 + m, cmp); //排序会更改原顺序,所以要id;
int la = 1; //每次只需要更新增多区间的值,la表示的是上一次区间终点;
for (int i = 1; i <= m; i++) {
for (int j = la; j <= e[i].y; j++) {
if (ma[a[j]]) {
add_dian(ma[a[j]], -1);
}
add_dian(j, 1); //每一次都把新位置存进来;
ma[a[j]] = j;
}
la = e[i].y + 1;
ans[e[i].id] = ask_sum(e[i].x, e[i].y); //前面所有种类和即为解;
}
for (int i = 1; i <= m; i++) {
cout << ans[i] << endl;
}
return 0;
}