CF846F Random Query

Description

给定长为 \(n\) 的数列 \(a_i\),随机地选取两个数 \(l,r\in [1,n]\),如果 \(l>r\),则交换 \(l,r\)。求区间 \([l,r]\) 中不同数字个数的期望。

Solution

我们只需要求出所有区间内的不同数字个数和再除以区间个数就可以了。实际上对特定区间统计不同数的个数并不太好处理,不管怎么优化复杂度都不优于 \(O(n^2)\),因为逃不掉枚举区间。那如果不枚举区间呢?考虑到不同数值的数相互间是独立的,所以想到可以枚举数值,转而快速统计合法的区间个数。一个特定的数 \(x\) 会对一个区间做出贡献,当且仅当 \(x\) 在该区间出现过,而且不管出现多少次都只会做出 \(1\) 的贡献。那么问题就转化成求有多少区间包含数 \(x\)。考虑补集转换,求总区间个数减去不包含 \(x\) 的区间个数。如果新开一个数组 \(b\)\(b_i\)\(1\) 当且仅当 \(a_i=x\) 成立。那么不包含 \(x\) 的区间就是一个由 \(0\) 组成的段,个数就是该 \(0\) 段长度的平方,最后再把所有 \(0\) 段的贡献求和。这一步显然可以用线段树来维护,类似于最大子段和的思路,一个节点维护线段树上区间前缀 \(0\) 和后缀 \(0\) 的个数,以及该区间包含的所求区间个数。复杂度 \(O(|V|\log n)\)\(|V|\) 是值域。

#include<stdio.h>
#include<vector>
using namespace std;
#define N 1000007
#define ll long long
#define lid id<<1
#define rid id<<1|1

inline int read(){
    int x=0,flag=1; char c=getchar();
    while(c<'0'||c>'9'){if(c=='0') flag=0; c=getchar();}
    while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+c-48;c=getchar();}
    return flag? x:-x;
}

struct Node{
    ll ls,rs,sum;
}t[N<<2];

int n;
vector<int> V[N];
bool a[N];

inline void update(int id,int lf,int rf){
    int mid=(lf+rf)>>1;
    if(t[lid].ls==mid-lf+1) t[id].ls=t[lid].ls+t[rid].ls;
    else t[id].ls=t[lid].ls;
    if(t[rid].rs==rf-mid) t[id].rs=t[rid].rs+t[lid].rs;
    else t[id].rs=t[rid].rs;
    ll x=t[lid].rs,y=t[rid].ls,z=t[lid].rs+t[rid].ls;
    t[id].sum=t[lid].sum-x*x+t[rid].sum-y*y+z*z;
}

int pos;
void modify(int id,int lf,int rf){
    if(lf==rf){
        a[pos]^=1;
        if(!a[pos]) t[id].ls=t[id].rs=t[id].sum=1;
        else t[id].ls=t[id].rs=t[id].sum=0;
    }else{
        int mid=(lf+rf)>>1;
        if(pos<=mid) modify(lid,lf,mid);
        else modify(rid,mid+1,rf);
        update(id,lf,rf);
    }
}

void build(int id,int lf,int rf){
    if(lf==rf)
        t[id].ls=t[id].rs=t[id].sum=1;
    else{
        int mid=(lf+rf)>>1;
        build(lid,lf,mid);
        build(rid,mid+1,rf);
        update(id,lf,rf);
    }
}

int main(){
    n=read();
    for(int i=1;i<=n;i++)
        V[read()].push_back(i);
    ll ans=0;
    build(1,1,n);
    for(int i=1;i<N;i++){
        if(!V[i].size()) continue;
        for(int j=0;j<V[i].size();j++)
            pos=V[i][j],modify(1,1,n);
        ans+=1ll*n*n-t[1].sum;
        for(int j=0;j<V[i].size();j++)
            pos=V[i][j],modify(1,1,n);
    }
    printf("%.6lf",1.0*ans/(1ll*n*n));
}
posted @ 2021-02-22 21:00  Kreap  阅读(51)  评论(0编辑  收藏  举报