查询任意区间内不同元素的个数

链接:https://www.nowcoder.com/acm/contest/139/J
来源:牛客网

题目描述
Given a sequence of integers a1, a2, ..., an and q pairs of integers (l1, r1), (l2, r2), ..., (lq, rq), find count(l1, r1), count(l2, r2), ..., count(lq, rq) where count(i, j) is the number of different integers among a1, a2, ..., ai, aj, aj + 1, ..., an.
输入描述:

The input consists of several test cases and is terminated by end-of-file.
The first line of each test cases contains two integers n and q.
The second line contains n integers a1, a2, ..., an.
The i-th of the following q lines contains two integers li and ri.

输出描述:

For each test case, print q integers which denote the result.

示例1
输入

3 2
1 2 1
1 2
1 3
4 1
1 2 3 4
1 3

输出

2
1
3

备注:

* 1 ≤ n, q ≤ 105
* 1 ≤ ai ≤ n
* 1 ≤ li, ri ≤ n
* The number of test cases does not exceed 10.


题意 : 输入一串数字,每次询问任意一个区间内有多少个不同的元素。
思路分析:
1 . 主席树 (不过 TLE, 就当是练习了)
主席树,一直在线操作的数据结构,显然我们应该要以每个位置去构造主席树,当这个元素第一次出现时,我们将其插入到树中,当这个元素不是第一次出现时,我们先将上一次出现的位置减 1 ,当前的位置再加 1 即可。注意在这个过程是要不断的新建点去操作,而不是去更新以前的某一棵线段树的。

#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int maxn = 2e5+5;
const int mod = 1e9+7;
const double eps = 1e-9;
const double pi = acos(-1.0);
const int inf = 0x3f3f3f3f;

template <class _T> inline void in(_T &_a) {
	int _f=0,_ch=getchar();_a=0;
	while(_ch<'0' || _ch>'9'){if(_ch=='-')_f=1;_ch=getchar();}
	while(_ch>='0' && _ch<='9')_a=(_a<<1)+(_a<<3)+_ch-'0',_ch=getchar();
	if(_f)_a=-_a;
}

int n, q;
int pre[maxn];
struct node
{
    int l, r;
    int sum;
}t[maxn*40];
int cnt;
int root[maxn];

void init(){
    cnt = 1;
    root[0] = 0;
    t[0].l = t[0].r = t[0].sum = 0;
}
int mp[maxn];

void update(int num, int &rt, int add, int l, int r){
    t[cnt++] = t[rt];
    rt = cnt-1;
    t[rt].sum += add;
    
    if (l == r) return;
    int m = (l+r)>>1;
    if (num <= m) update(num, t[rt].l, add, l, m);
    else update(num, t[rt].r, add, m+1, r);
}

int query(int pos, int rt, int l, int r){    
    if (l >= pos) return t[rt].sum; 
     
    int m = (l+r)>>1;
    int ans = 0;
    if (pos <= m) ans += query(pos, t[rt].l, l, m);
    ans += query(pos, t[rt].r, m+1, r);
    return ans;
}

int main() {
    //freopen("in.txt", "r", stdin);
    //freopen("out.txt", "w", stdout);
    int x;
    int a, b;
    
    while(~scanf("%d%d", &n, &q)){
        for(int i = 1; i <= n; i++){
            in(pre[i]);
            pre[i+n] = pre[i];            
            mp[i] = mp[i+n] = -1;
        }
        n *= 2;
        
        init(); 
        for(int i = 1; i <= n; i++){ 
            if (mp[pre[i]] == -1) {
                root[i] = root[i-1];
                update(i, root[i], 1, 1, n);
            }
            else {
                int temp = root[i-1];
                update(mp[pre[i]], temp, -1, 1, n);
                root[i] = temp;
                update(i, root[i], 1, 1, n);
            }
            mp[pre[i]] = i;
        }
        int jie = n/2; 
        while(q--){
            scanf("%d%d", &a, &b);
            printf("%d\n", query(b, root[jie+a], 1, n));
        }
    }    
    return 0;
}

 2 . 莫队 (一个好的输入挂就可以了)

using namespace std;
#define ll long long
const int maxn = 2e5+5;
const int mod = 1e9+7;
const double eps = 1e-9;
const double pi = acos(-1.0);
const int inf = 0x3f3f3f3f;

int read() {                    //输入挂
    int x = 0, f = 1; register char ch = getchar();
    while (ch<'0' || ch>'9') { if (ch == '-')f = -1; ch = getchar(); }
    while (ch >= '0'&&ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
    return x*f;
} int n, q; int pre[maxn]; struct node { int id; int zu, l, r; bool operator< (const node &v){ if (zu == v.zu) return r < v.r; return l < v.l; } }a[maxn]; int cnt[maxn], num[maxn]; int ans; inline void add(int x){ if (cnt[pre[x]] == 0) ans++; cnt[pre[x]]++; } inline void remove(int x){ if (cnt[pre[x]] == 1) ans--; cnt[pre[x]]--; } int main() { //freopen("in.txt", "r", stdin); //freopen("out.txt", "w", stdout); int l, r; while(~scanf("%d%d", &n, &q)){ for(int i = 1; i <= n; i++){ in(pre[i]); pre[i+n] = pre[i]; cnt[i] = cnt[i+n] = 0; } int uu = n; n = n*2; int unit = sqrt(n); for(int i = 1; i <= q; i++){ in(l), in(r); int f = r/unit; if (r%unit) f++; a[i].id = i, a[i].zu = f, a[i].l = r, a[i].r = uu+l; } sort(a+1, a+1+q); int l=a[1].l, r= a[1].l-1; ans=0; for(int i = 1; i <= q; i++){ while(l < a[i].l) remove(l++); while(r > a[i].r) remove(r--); while(l > a[i].l) add(--l); while(r < a[i].r) add(++r); num[a[i].id] = ans; } for(int i = 1; i <= q; i++){ printf("%d\n", num[i]); } } return 0; }

  3 . 树状数组

树状数组维护的是 i 这个位置对应的数出现在哪里, 当对应的数第一次出现时,直接将此位置添加到进去,当之前出现过时,先将上一次出现的位置处 -1 , 再将此位置 +1 。

再直观的看一下就是 x1, x2, * , *, x5, x6。 * 表示此位置的元素在后面出现过,顾将其置 0 ,非 * 的位置表示的是此位置有一个不同的元素,树状数组内部维护的元素同时保证了唯一性,当要查询某一个区间时,只需要去计算一下 c[r] - c[l-1] 即可

代码示例:

#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int maxn = 2e5+5;
const int mod = 1e9+7;
const double eps = 1e-9;
const double pi = acos(-1.0);
const int inf = 0x3f3f3f3f;
 
int n, q;
int pre[maxn];
struct node
{
    int id;
    int l, r;
     
    bool operator< (const node &v){
        return r < v.r;
    }
}arr[maxn];
int last[maxn];
int c[maxn], vis[maxn];
 
int lowbit(int k) {return k&(-k);}
 
void add(int pos, int pt){
    for(int i = pos; i <= n; i += lowbit(i)){
        c[i] += pt;
    }
}
int ans[maxn];
 
int getsum(int x){
    int res = 0;
    for(int i = x; i > 0; i -= lowbit(i)){
        res += c[i];
    }
    return res;
}
 
int main() {
    //freopen("in.txt", "r", stdin);
    //freopen("out.txt", "w", stdout);
    int l, r;
     
    while(~scanf("%d%d", &n, &q)){
        for(int i = 1; i <= n; i++){
            scanf("%d", &pre[i]);
            pre[i+n] = pre[i];           
            c[i] = c[i+n] = 0;
            vis[i] = vis[i+n] = 0;
            last[i] = last[i+n] = 0;
        }
        for(int i = 1; i <= q; i++){
            scanf("%d%d", &l, &r);
            arr[i] = {i, r, l+n};   
        }
        n = n*2;
        sort(arr+1, arr+1+q);
        int k = 1;
        for(int i = 1; i <= n; i++){
            if (!vis[pre[i]]){
                add(i, 1);
            }
            else {
                add(last[pre[i]], -1);
                add(i, 1);
            }
            vis[pre[i]] = 1;
            while(i == arr[k].r){
                int num = getsum(arr[k].r)-getsum(arr[k].l-1);
                ans[arr[k].id] = num;
                k++;
            }
            last[pre[i]] = i;
        }
        for(int i = 1; i <= q; i++){
            printf("%d\n", ans[i]);
        }
    }
    return 0;
}

 

posted @ 2018-07-30 09:06  楼主好菜啊  阅读(1747)  评论(0编辑  收藏  举报