Atcoder CODE FESTIVAL 2016 qual C 的E题 Encyclopedia of Permutations

题意:

对于一个长度为n的排列P,如果P在所有长度为n的排列中,按照字典序排列后,在第s位,则P的value为s

现在给出一个长度为n的排列P,P有一些位置确定了,另外一些位置为0,表示不确定。

现在问,P的所有可能的排列的value之和

n <= 500000

思路:

对于一个可能的排列,它的value为所有小于它的排列的个数 + 1反过来,对于一个排列a,如果P的可能的排列中有sum个排列大于a,则a对答案的贡献为sum

那我们就可以枚举位数,

一位一位的考虑:

对于2个排列P,b,我们假设它在第i位分出大小,即[1,i-1]的位置2个排列相同,并且第i位有P[i]  > b[i]

并且P就是满足条件的可能的排列,那我们算出这个时候有x个可能的P,y个可能的b,则对答案的贡献

为x * y

为了方便,我们从后面往前面枚举

那对于位数i,我们只需要分2 * 2种情况考虑:

自由的数表示没有被固定位置的数

suf表示i后面有多少个没有被确定的位置

sum表示P一共有多少个没有被确定的位置

1. Pi的位置确定了

1.1 b在i处放的数为[1,P[i]-1]中被固定在[i+1,n]中某一位的数,设有x个,则贡献:

   x * sum! * (n - i)!

  求x的话用一个树状数组bit记录就可以了,遇见固定的数就扔进bit里面

1.2   b在i处放的数为[1,P[i]-1]中自由的数,设有y个,则贡献:

  y * suf * (sum - 1)! * (n - i)!

  提前把所有自由的数放到一个树状数组bit2中,就可以快速得到y了

2 P[i]处的数没有确定

2.1 b在i处放的数也是自由的数,则贡献:

  C(sum,2) * (sum - 2)! * suf * (n - i)!

2.2 b在i处放的数是[1,P[i]-1]中被固定在i后面某一位的数

  假设P[i]放的是x,则b[i]放的应该是[1,x-1]中被固定在i后面的数,设[1,x-1]中被固定在i后面的数

  有y个,则贡献:

  y * (sum - 1)! * (n - i)!

  所以这部分的贡献需要我们枚举x,对于每一个x求有多少个y?

  不用,用线段树可以维护,总贡献为:

  seg[1].s * (sum - 1)! * (n - i)!

 

所以这道题目就搞定了,接下来说说线段树维护的是什么

Seg{int n;int ly;LL s}

线段树只需要维护自由的那一些数,因为枚举的x是自由的数

对于叶子节点i

s表示表示目前被固定的数中,比i小的有多少个

n表示这个叶子节点是不是自由的数

则对于所有节点:

n表示这个区间有多少个自由的数

s表示这个区间的自由的数的s之和

 

 

代码:

                                            
  //File Name: E.cpp
  //Author: long
  //Mail: 736726758@qq.com
  //Created Time: 2016年10月26日 星期三 10时29分59秒
                                   
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <iostream>
#include <map>
#include <set>
#include <math.h>
#include <vector>
#define LL long long
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
using namespace std;
const int MAXN  = 500000 + 5;
const int MOD = (int)1e9 + 7;
int p[MAXN],bit1[MAXN],bit2[MAXN],n;
bool use[MAXN];
LL jie[MAXN];
void update(int x,int add,int *bit){
    for(int i=x;i<=n;i+=i&-i)
        bit[i] += add;
}
int query(int x,int *bit){
    int res = 0;
    for(int i=x;i>0;i-=i&-i)
        res += bit[i];
    return res;
}
struct Seg{
    int n,ly;
    LL s;
}seg[MAXN << 2];
void pushup(int rt){
    seg[rt].s = seg[rt<<1].s + seg[rt<<1|1].s;
    seg[rt].n = seg[rt<<1].n + seg[rt<<1|1].n;
}
void pushdown(int rt){
    if(seg[rt].ly){
        int &ly = seg[rt].ly;
        int L = rt<<1,R = rt<<1|1;
        seg[L].ly += ly,seg[R].ly += ly;
        seg[L].s += (LL)seg[L].n * ly;
        seg[R].s += (LL)seg[R].n * ly;
        ly = 0;
    }
}
void build(int l,int r,int rt){
    seg[rt].s = seg[rt].ly = seg[rt].n = 0;
    if(l == r){
        if(!use[l]) seg[rt].n = 1;
        return ;
    }
    int m = l + r >> 1;
    build(lson);
    build(rson);
    pushup(rt);
}
void update(int L,int R,int add,int l,int r,int rt){
    if(L <= l && R >= r){
        seg[rt].ly += add;
        seg[rt].s += (LL)add * seg[rt].n;
        return ;
    }
    pushdown(rt);
    int m = l + r >> 1;
    if(L <= m) update(L,R,add,lson);
    if(R > m) update(L,R,add,rson);
    pushup(rt);
}
int init(){
    build(1,n,1);
    jie[0] = 1;
    for(int i=1;i<=n;i++)
        jie[i] = jie[i-1] * i % MOD;
    int sum = 0;
    for(int i=1;i<=n;i++){
        if(!use[i]){
            update(i,1,bit2);
            sum++;
        }
    }
    return sum;
}
LL solve(){
    int sum = init();
    LL ans = 0,tmp1,tmp2;
    int suf = 0;
    for(int i=n;i>0;i--){
        tmp1 = tmp2 = 0;
//        ans = 0;
        if(p[i]){
            int x = query(p[i] - 1,bit1);
            tmp1 = x * jie[sum] % MOD * jie[n - i] % MOD;
            int y = query(p[i],bit2);
            if(sum >= 1)
                tmp2 = (LL)y*suf % MOD * jie[sum-1] % MOD * jie[n-i] % MOD;
            ans = (ans + tmp1 + tmp2) % MOD;
            update(p[i],1,bit1);
            if(p[i] < n) update(p[i]+1,n,1,1,n,1);
        }
        else{
            if(sum >= 2)
                tmp1 = ((LL)sum * (sum - 1) / 2) % MOD * jie[sum-2] % MOD * suf % MOD * jie[n-i] % MOD;
            if(sum >= 1)
                tmp2 = seg[1].s % MOD * jie[sum-1] % MOD * jie[n-i] % MOD;
//            printf("tmp1 = %lld\ntmp2 = %lld\n",tmp1,tmp2);
            ans = (ans + tmp1 + tmp2) % MOD;
            suf++;
        }
//        printf("i = %d ans = %lld\n",i,ans);
    }
//    cout << jie[sum] << endl;
    ans = (ans + jie[sum]) % MOD;
    return ans;
}
int main(){
    scanf("%d",&n);
    memset(use,false,sizeof(use));
    for(int i=1;i<=n;i++){
        scanf("%d",p + i);
        use[p[i]] = true;
    }
    printf("%d\n",(int)solve());
    return 0;
}
View Code

 

posted on 2016-10-26 14:31  _fukua  阅读(348)  评论(0编辑  收藏  举报