[PKUSC2018]真实排名——线段树+组合数

题目链接:

[PKUSC2018]真实排名

对于每个数$val$分两种情况讨论:

1、当$val$不翻倍时,那么可以翻倍的是权值比$\frac{val-1}{2}$小的和大于等于$val$的。

2、当$val$翻倍时,显然权值在$[val,val*2-1]$的都要翻倍,剩下可以翻倍的是权值比$val$小的和大于等于$2*val$的。

用权值线段树维护权值,剩下的就是一步组合数。注意对$val=0$的特判。

#include<set>
#include<map>
#include<queue>
#include<stack>
#include<cmath>
#include<vector>
#include<cstdio>
#include<bitset>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
const int mod=998244353;
const int INF=1000000000;
int fac[100010];
int inv[100010];
int sum[10000000];
int ls[10000000];
int rs[10000000];
int cnt;
int n,k;
int a[100010];
int root;
void change(int &rt,int l,int r,int k)
{
    if(!rt)
    {
        rt=++cnt;
    }
    sum[rt]++;
    if(l==r)
    {
        return ;
    }
    int mid=(l+r)>>1;
    if(k<=mid)
    {
        change(ls[rt],l,mid,k);
    }
    else
    {
        change(rs[rt],mid+1,r,k);
    }
}
int query(int rt,int l,int r,int L,int R)
{
    if(L>R||!rt)
    {
        return 0;
    }
    if(L<=l&&r<=R)
    {
        return sum[rt];
    }
    int mid=(l+r)>>1;
    int res=0;
    if(L<=mid)
    {
        res+=query(ls[rt],l,mid,L,R);
    }
    if(R>mid)
    {
        res+=query(rs[rt],mid+1,r,L,R);
    }
    return res;
}
int C(int n,int m)
{
    if(n<m||m<0)
    {
        return 0;
    }
    return 1ll*fac[n]*inv[m]%mod*inv[n-m]%mod;
}
int solve(int val)
{
    int ans=0;
    int num=0;
    int sum=0;
    if(!val)
    {
        num=n-1;
        ans=(ans+C(num,k))%mod;
        ans=(ans+C(num,k-1))%mod;
    }
    else
    {
        num+=query(root,0,INF,0,(val-1)/2);
        num+=query(root,0,INF,val,INF)-1;
        ans=(ans+C(num,k))%mod;
        num=0;
        num+=query(root,0,INF,2*val,INF);
        num+=query(root,0,INF,0,val-1);
        sum+=query(root,0,INF,val,val*2-1);
        ans=(ans+C(num,k-sum))%mod;
    }
    return ans;
}
int main()
{
    scanf("%d%d",&n,&k);
    fac[0]=fac[1]=inv[0]=inv[1]=1;
    for(int i=2;i<=n;i++)
    {
        fac[i]=1ll*fac[i-1]*i%mod;
        inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
    }
    for(int i=2;i<=n;i++)
    {
        inv[i]=1ll*inv[i]*inv[i-1]%mod;
    }
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&a[i]);
        change(root,0,INF,a[i]);
    }
    for(int i=1;i<=n;i++)
    {
        printf("%d\n",solve(a[i]));
    }
}
posted @ 2019-05-29 09:44  The_Virtuoso  阅读(352)  评论(0编辑  收藏  举报