World is Exploding 树状数组+离散化

Given a sequence $A$ of length $n$, count how many quadruples $(a,b,c,d)$ satisfy: $a \neq b \neq c \neq d$ (i.e. the four indices are pairwise distinct), $1 \le a < b \le n$, $1 \le c < d \le n$, $A_a < A_b$, $A_c > A_d$.

Input
The input consists of multiple test cases.
Each test case begins with an integer $n$ on a single line.

The next line contains $n$ integers $A_1, A_2, \cdots, A_n$.
$1 \le n \le 50000$
$0 \le A_i \le 10^9$
Output
For each test case, output a line containing an integer.

Sample Input

4
2 4 1 3
4
1 2 3 4

Sample Output

1
0

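因为只考虑相对大小关系，所以先将数据离散化；然后用树状数组统计每个位置 i 前面、后面比 A_i 大或比 A_i 小的元素个数。先用“递增对总数 × 递减对总数”求出所有组合，再分类讨论下标重合的情况（a=c、a=d、b=c、b=d），减去重合的方案数。

(Only the relative order of the values matters, so the values are first coordinate-compressed. A Fenwick tree then counts, for each position i, how many elements before and after i are smaller or larger than A_i. The answer is (number of ascending pairs) × (number of descending pairs), minus the combinations in which the two pairs share an index: a=c, a=d, b=c or b=d. Two pairs can never share both indices, so subtracting each single-overlap case once gives the exact count.)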
#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<sstream>
#include<algorithm>
#include<queue>
#include<vector>
#include<cmath>
#include<map>
#include<stack>
#include<fstream>
#include<set>
#include<memory>
#include<bitset>
#include<string>
#include<functional>
using namespace std;
typedef long long LL;
// Capacity of every 1-indexed array below (valid indices 1..MAXN-1).
// The stated bound n <= 50000 leaves a wide margin.
#define MAXN 500009
// pre_min[i] / pre_max[i]:   how many j < i have a[j] smaller / larger than a[i].
// post_min[i] / post_max[i]: how many j > i have a[j] smaller / larger than a[i].
LL pre_max[MAXN], pre_min[MAXN], post_max[MAXN], post_min[MAXN];
// a:   the input sequence, overwritten in place by its compressed ranks.
// tmp: scratch copy that is sorted and deduplicated for coordinate compression.
LL a[MAXN], tmp[MAXN];
// T: Fenwick (binary indexed) tree of counts indexed by compressed rank.
// n: length of the current test case.
LL T[MAXN], n;
// Isolates the lowest set bit of x (e.g. lowbit(12) == 4, lowbit(0) == 0).
// This is the stride used when walking the Fenwick tree T.
LL lowbit(LL x)
{
    // x & (x - 1) clears the lowest set bit; XOR with x leaves only that bit.
    const LL without_lowest = x & (x - 1);
    return x ^ without_lowest;
}
// Fenwick-tree point update: adds 1 at position x of T.
// T has MAXN slots and slot 0 is never read by getsum, so valid positions
// are 1..MAXN-1. The walk therefore stops before MAXN instead of writing
// T[MAXN], which is one past the end of the array.
// A non-positive x is ignored. For x == 0, lowbit(0) == 0 would never advance
// the loop. For x < 0, the first write would index T out of bounds.
void update(LL x)
{
    if (x <= 0)
        return;
    while (x < MAXN)
    {
        T[x] += 1;
        x += lowbit(x);  // move to the next node whose range covers x
    }
}
// Fenwick-tree prefix query: returns the total of all increments made at
// positions 1..x. Since update() adds 1 per call, this is a count.
// Returns 0 for x <= 0.
LL getsum(LL x)
{
    LL total = 0;
    for (LL pos = x; pos > 0; pos -= lowbit(pos))
        total += T[pos];  // each node stores the sum of a block ending at pos
    return total;
}
int main()
{
    while (scanf("%lld", &n) != EOF)
    {
        memset(T, 0, sizeof(T));
        for (LL i = 1; i <= n; i++)
            scanf("%lld", &a[i]);
        memcpy(tmp, a, sizeof(a));
        sort(tmp + 1, tmp + n + 1);
        LL len = unique(tmp + 1, tmp + n + 1) - tmp;
        for (LL i = 1; i <= n; i++)
            a[i] = lower_bound(tmp + 1, tmp + len + 1, a[i]) - tmp;
        for (LL i = 1; i <= n; i++)
        {
            update(a[i]);
            pre_min[i] = getsum(a[i] - 1);
            pre_max[i] = i - getsum(a[i]);
        }
        LL sum1, sum2;
        sum1 = sum2 = 0;
        for (int i = 1; i <= n; i++)
        {
            post_min[i] = getsum(a[i] - 1) - pre_min[i];
            post_max[i] = n - getsum(a[i]) - pre_max[i];
            sum1 += post_min[i];
            sum2 += post_max[i];
        }
        LL ans = sum1*sum2;
        for (int i = 1; i <= n; i++)
        {
            ans -= pre_min[i] * pre_max[i];
            ans -= pre_min[i] * post_min[i];
            ans -= post_max[i] * pre_max[i];
            ans -= post_max[i] * post_min[i];
        }
        printf("%lld\n", ans);
    }
}

 

posted @ 2017-08-19 09:16  joeylee97  阅读(127)  评论(0)  编辑  收藏  举报