三元组

三元组

给定 $n$ 个两两不同的正整数 $a_1,a_2, \dots ,a_n$。

请你计算共有多少个三元组 $(i,j,k)$ 能够同时满足:

  1. $i<j<k$
  2. $a_i>a_j>a_k$

输入格式

第一行包含整数 $n$。

第二行包含 $n$ 个整数 $a_1,a_2, \dots ,a_n$。

输出格式

一个整数,表示满足条件的三元组数量。

数据范围

前 $4$ 个测试点满足 $3 \leq n \leq 4$。
所有测试点满足 $3 \leq n \leq {10}^{6}$,$1 \leq a_i \leq {10}^{9}$。

输入样例1:

3
3 2 1

输出样例1:

1

输入样例2:

3
2 3 1

输出样例2:

0

输入样例3:

4
10 8 3 1

输出样例3:

4

输入样例4:

4
1 5 4 3

输出样例4:

1

 

解题思路

  先说一下当时比赛的思路,貌似有点复杂。当时想到的是先把问题规模缩小,只考虑两个数的情况。枚举每一个数,看看前面有多少个数大于当前的数,这个可以用树状数组来实现,对于第$i$个数$a_i$,前面($1 \sim i - 1$)比$a_i$大的数有$cnt[i]$个。

  现在两个数的情况知道了,求三元组也是一样,枚举每一个数,然后看看前面有多少个数大于当前数,不过这里不再是加上比当前数要大的数的个数了,而是应该加上这些数所对应的$cnt$,因为求的是三元组。也是用树状数组来实现,不过每次动态加的数不再是$1$,而是$cnt[i]$。

  AC代码如下:

 1 #include <bits/stdc++.h>
 2 using namespace std;
 3 
 4 typedef long long LL;
 5 
 6 const int N = 1e6 + 10;
 7 
 8 int a[N];
 9 int cnt[N];
10 LL tr[N];
11 int xs[N], sz;
12 
13 int lowbit(int x) {
14     return x & -x;
15 }
16 
17 void add(int x, int c) {
18     for (int i = x; i <= sz; i += lowbit(i)) {
19         tr[i] += c;
20     }
21 }
22 
23 LL query(LL x) {
24     LL ret = 0;
25     for (int i = x; i; i -= lowbit(i)) {
26         ret += tr[i];
27     }
28     return ret;
29 }
30 
31 int find(int x) {
32     int l = 1, r = sz;
33     while (l < r) {
34         int mid = l + r >> 1;
35         if (xs[mid] >= x) r = mid;
36         else l = mid + 1;
37     }
38     return l;
39 }
40 
41 int main() {
42     int n;
43     scanf("%d", &n);
44     for (int i = 1; i <= n; i++) {
45         scanf("%d", a + i);
46         xs[++sz] = a[i];
47     }
48     
49     sort(xs + 1, xs + sz + 1);
50     // sz = unique(xs + 1, xs + sz + 1) - xs - 1;   // 可以省去,因为题目保证每个数都不同
51     
52     for (int i = 1; i <= n; i++) {
53         int t = find(a[i]);
54         cnt[i] += i - 1 - query(t); // 求大于a[i]的数的个数
55         add(t, 1);  // 加1,这里求的是二元组
56     }
57     
58     LL ret = 0, s = 0;
59     memset(tr, 0, sizeof(tr));
60     for (int i = 1; i <= n; i++) {
61         int t = find(a[i]);
62         ret += s - query(t);    // 求所有大于a[i]的数所对应cnt数组
63         s += cnt[i];
64         add(t, cnt[i]); // 加cnt[i],这里求的是三元组
65     }
66     printf("%lld", ret);
67     
68     return 0;
69 }

  其实不用想得这么麻烦,我们可以枚举$j$,这时就是看看左边有多少个$i$满足$a_i > a_j$,右边有多少个$k$满足$a_k < a_j$。分别用树状数组枚举两次,最后对于第$i$个数左边比它大的个数$l[i]$,右边比它小的个数$r[i]$,那么对于第$i$个数满足条件的三元组个数就是$l[i] \times r[i]$。

  AC代码如下:

 1 #include <bits/stdc++.h>
 2 using namespace std;
 3 
 4 typedef long long LL;
 5 
 6 const int N = 1e6 + 10;
 7 
 8 int a[N];
 9 int xs[N], sz;
10 int tr[N];
11 int cnt[N];
12 
13 int find(int x) {
14     int l = 1, r = sz;
15     while (l < r) {
16         int mid = l + r >> 1;
17         if (xs[mid] >= x) r = mid;
18         else l = mid + 1;
19     }
20     return l;
21 }
22 
23 int lowbit(int x) {
24     return x & -x;
25 }
26 
27 void add(int x, int c) {
28     for (int i = x; i <= sz; i += lowbit(i)) {
29         tr[i] += c;
30     }
31 }
32 
33 int query(int x) {
34     int ret = 0;
35     for (int i = x; i; i -= lowbit(i)) {
36         ret += tr[i];
37     }
38     return ret;
39 }
40 
41 int main() {
42     int n;
43     scanf("%d", &n);
44     for (int i = 1; i <= n; i++) {
45         scanf("%d", a + i);
46         xs[++sz] = a[i];
47     }
48     
49     sort(xs + 1, xs + sz + 1);
50     // sz = unique(xs + 1, xs + sz + 1) - xs - 1;   // 可以省去,因为题目保证每个数都不同
51     
52     for (int i = 1; i <= n; i++) {
53         int t = find(a[i]);
54         cnt[i] += i - 1 - query(t);
55         add(t, 1);
56     }
57     
58     LL ret = 0;
59     memset(tr, 0, sizeof(tr));
60     for (int i = n; i; i--) {
61         int t = find(a[i]);
62         ret += 1ll * cnt[i] * query(t - 1);
63         add(t, 1);
64     }
65     
66     printf("%lld", ret);
67     
68     return 0;
69 }

 

参考资料

  AcWing 4709. 三元组(AcWing杯 - 周赛):https://www.acwing.com/video/4505/

posted @ 2022-10-23 09:49  onlyblues  阅读(214)  评论(0编辑  收藏  举报
Web Analytics