题解 P4370 [Code+#4]组合数问题2
某次钟神讲课后来补了这道题。
Description
简化题意:给定一个 \(n\),要求选出 \(k\) 个组合数 \(C_a^b\),是他们的和最大。其中 \(a,b\) 必须满足 \(0 \le a \le b \le n\) 。
Solution
先考虑如何找到最大的 \(k\) 个组合数。
我们知道组合数的递推式是 \(C_n^m = C_{n-1}^{m-1} + C_{n-1}^{m}\),并且这个式子就是杨辉三角的形式。
尝试把前几行写出来。
1
1 1
1 2 1
1 3 3 1
1 4 6 4 1
1 5 10 10 5 1
1 6 15 20 15 6 1
观察一下发现对于第 \(i\) 行,第 \(\frac{i}{2}\) 个数是最大的;对于每一列,越靠下越大。并且最大的是 \(C_n^{\frac{n}{2}}\)
那么考虑用广搜的思想,先把最大的 \(C_n^{\frac{n}{2}}\) 加进去,然后不断向周围三个方向扩展,取出前 \(k\) 大即可。
扩展的时候注意判断是否在界内。
注意判断该位置是否入队过,这里使用 map 标记判断。
但是我们忽视一个问题:我们并不能比较 \(C_{a}^{b}\) 的大小。因为我们求不出来,取模的话就会使得大小无法比较。如何解决?
根据钟神的思路想到高中的一个知识点:
\[\log (x \times y) = \log x + \log y
\]
\[\log (\frac{x}{y}) = \log x - \log y
\]
那么我们的 \(C_{a}^{b}\) 是不是也可以表示了?
设 \(Log_i = \sum_{j=1}^{i} \log j\),有:
\[\log C_a^b = Log_a - Log_b - Log_{a-b}
\]
众所周知, \(f(x) = \log x\) 是单调递增函数,所以加入优队时比较 \(\log C_a^b\) 的大小就好了。
Code
/*
Work by: Suzt_ilymics
Problem: P4370 [Code+#4]组合数问题2
Knowledge: 优先队列,log部分知识
Time: O(能过)
*/
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#include<cmath>
#include<map>
#define LL long long
#define orz cout<<"lkp AK IOI!"<<endl
using namespace std;
const int MAXN = 1e6+5;
const int INF = 1e9+7;
const int mod = 1e9+7;
int dx[] = {0, 0, -1, 0};
int dy[] = {0, -1, 0, 1};
struct node {
int n, m; double val;
bool operator < (const node &b) const { return val < b.val; }
};
int n, k; LL ans = 0;
double Log[MAXN]; // 开 double 防止精度问题
int fac[MAXN], inv[MAXN];
priority_queue<node> q;
map<int, bool> Map[MAXN];
int read(){
int s = 0, f = 0;
char ch = getchar();
while(!isdigit(ch)) f |= (ch == '-'), ch = getchar();
while(isdigit(ch)) s = (s << 1) + (s << 3) + ch - '0' , ch = getchar();
return f ? -s : s;
}
bool Check(int x, int y) { return x < 0 || y < 0 || y > x; }
LL calc(int n, int m) { return 1ll * fac[n] * inv[m] % mod * inv[n - m] % mod; }
void Init(int limit) {
fac[0] = 1, inv[0] = 1;
fac[1] = 1, inv[1] = 1;
for(int i = 1; i <= limit; ++i) Log[i] = Log[i - 1] + log(i); //预处理 Log
for(int i = 2; i <= limit; ++i) fac[i] = 1ll * fac[i - 1] * i % mod, inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod; // 预处理阶乘
for(int i = 2; i <= limit; ++i) inv[i] = 1ll * inv[i - 1] * inv[i] % mod; // 预处理阶乘的逆元
}
void Solve() {
q.push((node){n, n/2, Log[n] - Log[n/2] - Log[n - n/2]}); // 把最大的点加进去
Map[n][n/2] = true;
for(int i = 1; i <= k; ++i) {
node u = q.top(); q.pop();
// cout<<u.n<<" "<<u.m<<" \n";
// cout<<calc(u.n, u.m)<<"\n";
ans = (ans + calc(u.n, u.m)) % mod;
for(int j = 1; j <= 3; ++j) { // 枚举三个方向
int dn = u.n + dx[j], dm = u.m + dy[j];
if(Check(dn, dm) || Map[dn][dm]) continue; // 判断是否出界及是否标记过
q.push((node){dn, dm, Log[dn] - Log[dm] - Log[dn - dm]});
Map[dn][dm] = true;
}
}
}
int main()
{
Init(1000000);
n = read(), k = read();
Solve();
printf("%lld", ans);
return 0;
}