【CF】Sereja and Arcs

#include <bits/stdc++.h>
// 64-bit integer type used for all modular arithmetic below. A type alias
// (rather than a #define) is scoped and type-checked like any other name.
using llong = long long;
using namespace std;

// Array capacity: both positions 1..n and input values a[i] are used as
// indices into arrays of size N+3, so each must stay at most N+2.
constexpr int N = 100000;
// Modulus applied to every count in the program.
constexpr int P = 1000000007;
// Modular inverse of 2 (P is odd, so 2 * (P+1)/2 == P+1 == 1 mod P); equals 500000004.
constexpr llong INV2 = (P + 1) / 2;

int nxt[N+3];     // nxt[i]: previous position holding the same value as a[i], 0 if none
int lstpos[N+3];  // lstpos[v]: most recent position seen with value v (used to build nxt)
int a[N+3];       // input sequence, 1-indexed
int num[N+3];     // num[v]: total number of occurrences of value v
int cnum[N+3];    // cnum[v]: running occurrence counter used by getans1
// Scratch prefix arrays; only tmp1 is referenced in the code shown here.
llong tmp0[N+3],tmp1[N+3],tmp2[N+3];
vector<int> clrpos[N+3];  // clrpos[v]: positions holding value v, in increasing order
// n: sequence length; m: largest value seen; B: occurrence threshold, values with
// num[v] > B are treated as "big" and the rest as "small".
int n,m,B;
// Partial counts produced by getans0/getans1/getans2a/getans2b/getans2c;
// ans = (ans0 - ans1 - ans2a - ans2b - ans2c) mod P.
llong ans1,ans2a,ans2b,ans2c,ans0,ans;
 
// Returns x choose 2, i.e. x*(x-1)/2, reduced modulo P.
// The intermediate product x*(x-1) must fit in a long long.
llong C2(llong x)
{
    const llong pairs = x * (x - 1ll) / 2ll;
    return pairs % P;
}
// Adds y to x modulo P, stores the result back into x, and returns the new x.
// C++ '%' keeps the sign of the dividend, so x may become negative when x + y < 0.
// The value is returned explicitly: letting a non-void function fall off its end
// is undefined behaviour.
llong update(llong &x,llong y)
{
    x = (x + y) % P;
    return x;
}
 
struct BITree
{
    llong tr[N+3]; int siz;
    void addval(int lrb,llong val)
    {
        while(lrb<=siz)
        {
            update(tr[lrb],val);
            lrb += (lrb&(-lrb));
        }
    }
    llong querysum(llong rb)
    {
        llong ret = 0ll;
        while(rb)
        {
            update(ret,tr[rb]);
            rb -= (rb&(-rb));
        }
        return ret;
    }
    void clear()
    {
        for(int i=0; i<=siz; i++) tr[i] = 0ll;
    }
} bit1,bit2;
 
void getans0() 
{
    llong cur = 0ll;
    for(int i=1; i<=m; i++)
    {
        llong tmp = C2(num[i]);
        update(ans0,cur*tmp%P);
        update(cur,tmp);
    }
}
 
void getans1() 
{
    llong tmp = 0ll; 
    for(int i=1; i<=n; i++)
    {
        update(ans1,(tmp-C2(cnum[a[i]]))*(num[a[i]]-cnum[a[i]]-1)); 
        update(tmp,(llong)cnum[a[i]]); 
        cnum[a[i]]++;
    }
}
 
void getans2a()  
{
    bit1.siz = n; bit1.clear(); llong cur = 0ll;
    for(int i=1; i<=n; i++) 
    {
        if(num[a[i]]<=B)
        {
            int tnum = 0;
            for(int j=nxt[i]; j; j=nxt[j])
            {
                llong tmp = cur-bit1.querysum(j)-C2(tnum)+P+P;
                update(ans2a,tmp);
                tnum++; 
            }
            for(int j=nxt[i]; j; j=nxt[j])  
            {
                cur++; 
                bit1.addval(j,1);
            }
        }
    }
}
 
void getans2b() 
{
    for(int i=1; i<=m; i++) 
    {
        if(num[i]>B) 
        {
            tmp1[0] = 0ll; for(int j=1; j<=n; j++) tmp1[j] = tmp1[j-1]+(a[j]==i?1:0);
            for(int j=1; j<=m; j++)  
            {
                if(num[j]<=B)
                {
                    llong cur = 0ll;
                    for(int k=0; k<clrpos[j].size(); k++)
                    {
                        int rb = clrpos[j][k];
                        llong tmp = (num[i]-tmp1[rb])*cur%P;
                        update(ans2b,tmp); 
                        update(cur,tmp1[rb]); 
                    }
                }
            }
        }
    }
}
 
// For every "big" value i (num[i] > B) and every other value j, adds to ans2c
// (mod P) the number of ways a pair of i-positions lies strictly inside a pair
// of j-positions.
//
// Let p_0 < p_1 < ... be the positions of j and s_k = tmp1[p_k] the number of
// i-occurrences in a[1..p_k]. The pair (p_t, p_k), t < k, encloses d = s_k - s_t
// occurrences of i, i.e. C2(d) pairs. The inner loop accumulates
//     sum_{t<k} [(s_k - s_t)^2 - (s_k - s_t)] = 2 * sum_{t<k} C2(s_k - s_t)
// in O(1) per k, using the running sums cur1 = sum s_t and cur2 = sum s_t^2.
// The final multiplication by INV2 removes the factor 2.
void getans2c() 
{
    for(int i=1; i<=m; i++)  
    {
        if(num[i]>B)
        {
            // tmp1[p] = number of positions q <= p with a[q] == i.
            tmp1[0] = 0; for(int j=1; j<=n; j++) tmp1[j] = tmp1[j-1]+(a[j]==i?1:0);
            for(int j=1; j<=m; j++) 
            {
                if(i==j) continue;
                // cur1 = sum of s_t and cur2 = sum of s_t^2 over earlier t (mod P).
                llong cur1 = 0ll,cur2 = 0ll;
                for(int k=0; k<clrpos[j].size(); k++) 
                {
                    int ra = clrpos[j][k];
                    // k * s_k^2
                    llong tmp = tmp1[ra]*tmp1[ra]%P*k%P;
                    update(ans2c,tmp);
                    // -s_k * (2 * sum s_t + k); adding P makes the reduced term non-negative.
                    tmp = tmp1[ra]*(-2ll*cur1-k)%P+P;
                    update(ans2c,tmp);
                    // sum s_t^2 + sum s_t (the extra P vanishes after reduction).
                    tmp = cur2+cur1+P;
                    update(ans2c,tmp);
                    update(cur2,tmp1[ra]*tmp1[ra]);
                    update(cur1,tmp1[ra]);
                }
            }
        }
    }
    // Divide the doubled total by 2 modulo P.
    ans2c = ans2c*INV2%P;
}
 
// Reads n and then the values a[1..n] from stdin, and builds the per-value tables
// (num, m, clrpos, nxt/lstpos). It then runs the getans* routines and prints
// (ans0 - ans1 - ans2a - ans2b - ans2c) normalised into [0, P).
int main()
{
    scanf("%d",&n);
    // Values occurring more than B times are handled as "big" by getans2b/getans2c.
    B = sqrt(n)/2;

    for(int pos = 1; pos <= n; ++pos)
    {
        scanf("%d",&a[pos]);
        const int v = a[pos];
        num[v]++;
        m = max(m, v);
        clrpos[v].push_back(pos);
        // Link pos to the previous occurrence of the same value (0 if none).
        nxt[pos] = lstpos[v];
        lstpos[v] = pos;
    }

    getans0();
    getans1();
    getans2a();
    getans2b();
    getans2c();

    // Partial answers may be negative after C++ '%', so normalise twice.
    ans = ((ans0-ans1-ans2a-ans2b-ans2c)%P+P)%P;
    printf("%lld\n",ans);
    return 0;
}

 

posted @ 2020-05-29 19:04  ywwywwyww  阅读(438)  评论(0)  编辑  收藏  举报