P4370 [Code+#4]组合数问题2
知识点:堆,小技巧
原题面:Luogu
简述
定义两个组合数 \(a_i\choose b_i\) 与 \(a_j\choose b_j\) 不同,当且仅当 \(a_i\not = a_j\) 或 \(b_i\not ={b_j}\)。
给定参数 \(n,k\),要求选出 \(k\) 个不同的组合数 \(a_i\choose b_i\),满足 \(0\le b_i\le a_i\le n\),最大化它们的和。
输出它们的和 \(\bmod 10^9+ 7\) 的值。
\(1\le n\le 10^6\),\(1\le k\le 10^5\)。
分析
显然答案即为前 \(k\) 大的组合数的和。
显然最大的组合数为 \(n\choose \frac{n}{2}\),它一定需要被选择。
考虑次大的组合数的位置,显然只可能出现在以下四个位置:
\(n\choose \frac{n}{2}-1\),\(n\choose \frac{n}{2}+1\),\(n-1\choose \frac{n}{2}-1\),\(n-1\choose \frac{n}{2}\)。
然后可以发现更一般的规律,比组合数 \(a\choose b\) 小的最大的数只能出现在下列四个位置:
\(a\choose b-1\),\(a\choose b+1\),\(a-1\choose b-1\),\(a-1\choose b\)。
考虑使用元素降序的优先队列进行维护,初始时队列中仅有 \(n\choose \frac{n}{2}\)。
每次取出队首元素,将其加入答案,枚举四个次小位置加入队列,注意去重。
取出 \(k\) 个元素后即得答案。
还有个问题,组合数会很大,如何定义优先级比较方式。
考虑将组合数拆成下降幂,发现下面三个命题,互为充要条件:
预处理前缀 \(\log\) 值的和(即阶乘的 \(\log\) 值)即可比较两个下降幂的大小。
注意 priority_queue
奇怪的重载优先级。
代码
//知识点:堆,小技巧
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <queue>
#include <set>
#define LL long long
#define pr std::pair
#define mp std::make_pair
const LL kMod = 1e9 + 7;
const int kMaxn = 1e6 + 10;
const int en[5] = {0, 0, -1, -1};
const int em[5] = {-1, 1, 0, -1};
//=============================================================
struct Data {
int n, m;
double val;
//注意奇怪的重载
bool operator < (const Data &sec) const {
return val < sec.val;
}
};
int n, k;
LL ans, fac[kMaxn];
double LogFac[kMaxn];
std::priority_queue <Data> q;
std::set <pr <int, int> > In_queue, Hash;
//=============================================================
inline int read() {
int f = 1, w = 0;
char ch = getchar();
for (; !isdigit(ch); ch = getchar())
if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
return f * w;
}
void Chkmax(int &fir_, int sec_) {
if (sec_ > fir_) fir_ = sec_;
}
void Chkmin(int &fir_, int sec_) {
if (sec_ < fir_) fir_ = sec_;
}
LL QuickPow(LL x_, LL y_, LL mod_) {
LL ret = 1;
for (; y_; y_ >>= 1) {
if (y_ & 1) ret = ret * x_ % mod_;
x_ = x_ * x_ % mod_;
}
return ret;
}
double LogC(int n_, int m_) {
return LogFac[n_] - LogFac[m_] - LogFac[n_ - m_];
}
LL C(int n_, int m_) {
if (n_ == m_ || m_ == 0) return 1ll;
return fac[n_] *
QuickPow(fac[m_], kMod - 2, kMod) % kMod *
QuickPow(fac[n_ - m_], kMod - 2, kMod) % kMod;
}
//=============================================================
int main() {
n = read(), k = read();
fac[0] = 1;
for (int i = 1; i <= n; ++ i) {
fac[i] = 1ll * fac[i - 1] * i % kMod;
LogFac[i] = LogFac[i - 1] + log(1.0 * i);
}
q.push((Data) {n, n / 2, LogC(n, n / 2)});
In_queue.insert(mp(n, n / 2));
for (int i = 1; i <= k; ++ i) {
if (q.empty()) break;
Data top = q.top();
q.pop();
Hash.insert(mp(top.n, top.m));
In_queue.erase(mp(top.n, top.m));
ans = (ans + C(top.n, top.m)) % kMod;
for (int j = 0; j < 4; ++ j) {
int newn = top.n + en[j];
int newm = top.m + em[j];
if (newn < 0 || newm < 0 || newn < newm) continue ;
if (Hash.count(mp(newn, newm))) continue ;
if (In_queue.count(mp(newn, newm))) continue ;
In_queue.insert(mp(newn, newm));
q.push((Data) {newn, newm, LogC(newn, newm)});
Data now = q.top();
}
}
printf("%lld\n", ans);
return 0;
}
/*
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <queue>
#define LL long long
const LL kMod = 1e9 + 7;
const int kMaxn = 1e6 + 10;
//=============================================================
struct Data {
int n, m;
double val;
bool operator < (const Data &sec) const {
return val < sec.val;
}
};
int n, k;
LL ans, fac[kMaxn];
double LogFac[kMaxn];
std::priority_queue <Data> q;
//=============================================================
inline int read() {
int f = 1, w = 0;
char ch = getchar();
for (; !isdigit(ch); ch = getchar())
if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
return f * w;
}
void Chkmax(int &fir_, int sec_) {
if (sec_ > fir_) fir_ = sec_;
}
void Chkmin(int &fir_, int sec_) {
if (sec_ < fir_) fir_ = sec_;
}
LL QuickPow(LL x_, LL y_, LL mod_) {
LL ret = 1;
for (; y_; y_ >>= 1) {
if (y_ & 1) ret = ret * x_ % mod_;
x_ = x_ * x_ % mod_;
}
return ret;
}
double LogC(int n_, int m_) {
return LogFac[n_] - LogFac[m_] - LogFac[n_ - m_];
}
LL C(int n_, int m_) {
return fac[n_] *
QuickPow(fac[m_], kMod - 2, kMod) % kMod *
QuickPow(fac[n_ - m_], kMod - 2, kMod) % kMod;
}
//=============================================================
int main() {
n = read(), k = read();
fac[0] = 1;
for (int i = 1; i <= n; ++ i) {
fac[i] = 1ll * fac[i - 1] * i % kMod;
LogFac[i] = LogFac[i - 1] + log2(1.0 * i);
}
for (int i = 0; i <= n; ++ i) {
q.push((Data) {n, i, LogC(n, i)});
}
for (int i = 1; i <= k; ++ i) {
Data top = q.top(); q.pop();
ans = (ans + C(top.n, top.m)) % kMod;
q.push((Data) {top.n - 1, top.m, LogC(top.n - 1, top.m)});
}
printf("%lld\n", ans);
return 0;
}
*/