题解 P4370 [Code+#4]组合数问题2

某次钟神讲课后来补了这道题。

Description

简化题意:给定一个 \(n\),要求选出 \(k\) 个组合数 \(C_a^b\),是他们的和最大。其中 \(a,b\) 必须满足 \(0 \le a \le b \le n\)

Solution

先考虑如何找到最大的 \(k\) 个组合数。

我们知道组合数的递推式是 \(C_n^m = C_{n-1}^{m-1} + C_{n-1}^{m}\),并且这个式子就是杨辉三角的形式。

尝试把前几行写出来。

1
1 1
1 2  1
1 3  3  1
1 4  6  4 1
1 5 10 10 5  1
1 6 15 20 15 6 1

观察一下发现对于第 \(i\) 行,第 \(\frac{i}{2}\) 个数是最大的;对于每一列,越靠下越大。并且最大的是 \(C_n^{\frac{n}{2}}\)

那么考虑用广搜的思想,先把最大的 \(C_n^{\frac{n}{2}}\) 加进去,然后不断向周围三个方向扩展,取出前 \(k\) 大即可。

扩展的时候注意判断是否在界内。

注意判断该位置是否入队过,这里使用 map 标记判断。

但是我们忽视一个问题:我们并不能比较 \(C_{a}^{b}\) 的大小。因为我们求不出来,取模的话就会使得大小无法比较。如何解决?

根据钟神的思路想到高中的一个知识点:

\[\log (x \times y) = \log x + \log y \]

\[\log (\frac{x}{y}) = \log x - \log y \]

那么我们的 \(C_{a}^{b}\) 是不是也可以表示了?

\(Log_i = \sum_{j=1}^{i} \log j\),有:

\[\log C_a^b = Log_a - Log_b - Log_{a-b} \]

众所周知, \(f(x) = \log x\) 是单调递增函数,所以加入优队时比较 \(\log C_a^b\) 的大小就好了。

Code

/*
Work by: Suzt_ilymics
Problem: P4370 [Code+#4]组合数问题2
Knowledge: 优先队列,log部分知识 
Time: O(能过)
*/
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#include<cmath>
#include<map>
#define LL long long
#define orz cout<<"lkp AK IOI!"<<endl

using namespace std;
const int MAXN = 1e6+5;
const int INF = 1e9+7;
const int mod = 1e9+7;
int dx[] = {0, 0, -1, 0};
int dy[] = {0, -1, 0, 1};

struct node {
    int n, m; double val;
    bool operator < (const node &b) const { return val < b.val; }
};

int n, k; LL ans = 0;
double Log[MAXN]; // 开 double 防止精度问题 
int fac[MAXN], inv[MAXN];
priority_queue<node> q;
map<int, bool> Map[MAXN];

int read(){
    int s = 0, f = 0;
    char ch = getchar();
    while(!isdigit(ch))  f |= (ch == '-'), ch = getchar();
    while(isdigit(ch)) s = (s << 1) + (s << 3) + ch - '0' , ch = getchar();
    return f ? -s : s;
}

bool Check(int x, int y) { return x < 0 || y < 0 || y > x; }
LL calc(int n, int m) { return 1ll * fac[n] * inv[m] % mod * inv[n - m] % mod; }

void Init(int limit) { 
    fac[0] = 1, inv[0] = 1;
    fac[1] = 1, inv[1] = 1;
    for(int i = 1; i <= limit; ++i) Log[i] = Log[i - 1] + log(i); //预处理 Log 
    for(int i = 2; i <= limit; ++i) fac[i] = 1ll * fac[i - 1] * i % mod, inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod; // 预处理阶乘 
    for(int i = 2; i <= limit; ++i) inv[i] = 1ll * inv[i - 1] * inv[i] % mod; // 预处理阶乘的逆元 
}

void Solve() {
    q.push((node){n, n/2, Log[n] - Log[n/2] - Log[n - n/2]}); // 把最大的点加进去 
    Map[n][n/2] = true;
    for(int i = 1; i <= k; ++i) {
        node u = q.top(); q.pop();
//        cout<<u.n<<" "<<u.m<<" \n";
//        cout<<calc(u.n, u.m)<<"\n";
        ans = (ans + calc(u.n, u.m)) % mod;
        for(int j = 1; j <= 3; ++j) { // 枚举三个方向 
            int dn = u.n + dx[j], dm = u.m + dy[j];
            if(Check(dn, dm) || Map[dn][dm]) continue; // 判断是否出界及是否标记过 
            q.push((node){dn, dm, Log[dn] - Log[dm] - Log[dn - dm]});
            Map[dn][dm] = true;
        }
    }
}

int main()
{
    Init(1000000); 
    n = read(), k = read();
    Solve();
    printf("%lld", ans);
    return 0;
}
posted @ 2021-05-05 08:23  Suzt_ilymtics  阅读(112)  评论(4编辑  收藏  举报