Codeforces Round #775 (Div. 2, based on Moscow Open Olympiad in Informatics)

Codeforces Round #775 (Div. 2, based on Moscow Open Olympiad in Informatics)

E

要找所有字典序小于t的方案数。

对于字典序问题一般枚举第一个小于模板串的地方

设位置i是第一个字典序小的地方,i之前的前缀都和模板串相同,叫i为关键点

显然i之后可以随意组合。

当确定关键点后,方案数就是剩下的数字的所有排列方案数。

高中数学告诉我们剩余数字的方案数是

\[r = \frac {tot!}{C_1!C_2!\dots C_i! \dots C_n!} \]

其中 \(C_i\) 表示数字 \(i\) 的个数

现在在关键点上放置一个字典序较模板串小的数字j

还剩 \(tot - 1\) 个数字,数字 j 的数量变为 \(C_j - 1\) ,

此时在关键点放数字 j 的方案数就是

\[\frac {(tot - 1)!}{C_1!C_2!\dots (C_j - 1)! \dots C_n!} \]

这个式子只要对 r 进行一定变形就可以得到

\[r * \frac {C_j}{tot} = \frac {(tot - 1)!}{C_1!C_2!\dots (C_j - 1)! \dots C_n!} \]

在关键点上可以放所有比模板串字典序小的数字,这个可以用前缀和求

\[r * \frac {\sum {C_j}}{tot} \]

其中数字 j 是所有满足字典序小于模板串当前位置的数字

上式表示关键点 i 上放所有字典序小于模板串当前位置的数字的方案数

我们只要枚举所有关键点,就可以得到答案

但当枚举最后一个位置时,没有计算刚好是模板串的情况,所有还要加一个特判。

另外,在打包计算时,我们得知道如何计算快速得到剩下满足要求的数字

这是一个求取动态前缀和的过程,用树状数组即可。

#include<bits/stdc++.h>
#define yes puts("yes");
#define inf 0x3f3f3f3f
#define ll long long
#define linf 0x3f3f3f3f3f3f3f3f
#define debug(x) cout<<"> "<< x<<endl;
#define ull unsigned long long
#define endl '\n'
#define lowbit(x) x&-x
#define int long long
using namespace std;
typedef pair<int,int> PII;
const int MAXN =10 + 2e5 ,mod= 998244353;
/* 
    加强版找字典序比较小的重排列方案数问题

    3 4
    1 2 2
    2 1 2 1

    1 2 2
    2 1 2
    ai C(k,a1) * C(k,a2) * C(k,a3) * C(k,a4) ``` 
    
 */
int ksm(int a,int b) {
    int ans = 1;
    a %= mod;
    while(b) {
        if(b & 1) ans = ans * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return ans;
}
int inv[MAXN],fac[MAXN],invfac[MAXN];
void init() {
    fac[0] = fac[1] = 1;
    invfac[0] = 1;
    inv[1] = 1;
    for(int i = 1;i < MAXN;i ++) {
        fac[i] = fac[i - 1] * i % mod;
    }
    for(int i = 2;i < MAXN;i ++) {
        inv[i] = (mod - mod / i) * inv[mod % i] % mod;
    }
    for(int i = 1;i < MAXN;i ++) {
        invfac[i] = invfac[i - 1] * inv[i] % mod;
    }
}
int tr[MAXN];
void update(int p,int v) {
    for(int i = p;i < MAXN;i += lowbit(i)) {
        tr[i] += v;
    }
}
int query(int p) {
    int ans = 0;
    for(int i = p;i > 0;i -= lowbit(i)) {
        ans += tr[i];
    }
    return ans;
}
void solve()
{    
    init();
    int N,M;
    cin >> N >> M;
    vector<int> c(MAXN,0),b(M + 1);
    for(int i = 1;i <= N;i ++) {
        int t;
        cin >> t;
        c[t] ++;
        update(t,1);
    }
    for(int i = 1;i <= M;i ++) {
        cin >> b[i];
    }

    int r = fac[N];
    for(int i = 1;i < MAXN;i ++) {
        if(c[i]) r = r * invfac[c[i]] % mod; 
    }

    int ans = 0;
    for(int i = 1;i <= min(N,M);i ++) {
        int t = b[i];
        ans += r * inv[N - i + 1] % mod * query(t - 1) % mod;
        ans %= mod;
        if(!c[t]) {
            break;
        }
        r = r * inv[N - i + 1] % mod * c[t] % mod;
        
        c[t] --;
        update(t,-1);
    }

    if(c == vector<int>(MAXN,0) && N < M) { 
        ans = (ans + 1) % mod;
    }

    cout << ans << endl;
}
signed main()
{
    ios::sync_with_stdio(false);cin.tie(nullptr);cout.tie(nullptr);

    //int T;cin>>T;
    //while(T--)
        solve();

    return 0;
}
posted @ 2022-03-15 15:16  Mxrurush  阅读(48)  评论(1编辑  收藏  举报