题解-gym102978C Count Min Ratio [*hard]
题面
给定 \(B\) 个蓝色的球和 \(R\) 个红色的球以及一个绿色的球,同颜色的球不可区分。对于一种球的排列方式,记 \(l_B\) 是绿球左边的蓝球个数,\(r_B\) 是绿右边的蓝球个数,\(l_R\) 是球左边的红球个数,\(r_R\) 是球右边边的红球个数,则该排列的权值是最大的正整数 \(x\) 满足 \(l_B \times x \le l_R\),\(r_B \times x\le r_R\)
数据范围:\(1 \le B \le 10^6\),\(1 \le R \le 10^{18}\)。
题解
法 1
考虑枚举绿球右边的红球和蓝球个数:
考虑右边那堆东西的组合意义:一条路径的权值是他经过的点中满足 \(y = Ax + P\) 的点数,要求所有从 \((0, 0)\) 出发,到达 \((B, R)\) 的路径的权值和。
首先我们先考虑一个前置的问题:
求从 \((0, 0)\) 到 \((W, AW + P)\) 的路径条数 (设为 \(f(W, A, P)\)) 满足这条路径不穿过 \(y = Ax+P\)。
考虑一条从 \((0, 0)\) 到 \((W, AW + P)\) 的路径,如果不穿过 \(y = Ax+P\),我们就枚举他经过这条直线的第一个位置:
考虑一条从 \((0, 0)\) 到 \((W - 1, AW + P + 1)\) 的路径,他必然穿过 \(y = Ax+P\):
我们发现 \((1) - A (2)\) 可得 \(f(W,A,P) = \binom{(A+1)W+P}{W} - A \binom{(A+1)W+P}{W-1}\)
回到现在要解决的问题:一条路径的权值是他经过的点中满足 \(y = Ax + P\) 的点数,要求所有从 \((0, 0)\) 出发,到达 \((W, H)\) 的路径的权值和。保证 \(AW+P \le H\) 记为 \(g(W, H, A, P)\)。
枚举路径上的点:\(\sum\limits_{i = 0}^{W} \binom{(A + 1) i+P}{i} \binom{W + H - (A+1) i - P}{W - i}\)。
考虑其组合意义,就是先从 \((0, 0)\) 走到 \((i, Ai+P)\),再在不经过 \(y = Ai+P\) 的情况下走到 \((W, H)\)。
可以把他想象成枚举最后一次 碰到 \(y = Ai+P\) 的位置,最终要到达 \((W, H + 1)\) (因为碰到之后必然会向上走,然后不能 穿过 \(y = Ai + P + 1\),因此要到达的点是 \((W, H + 1)\))。
其实我们算的就是从 \((0, 0)\) 到 \((H + 1, W)\) ,因此 \(g(W, H, A, P) - Ag(W - 1, H+1, A, P) = \binom{H+W+1}{W}\)。结果和 \(P\) 无关!
因此 \(g(W,H,A,P) = \sum\limits_{i = 0}^{W} \binom{H+W+1}{i} A^{W-i}\)
接下来就很好做了:要算的是 \(\sum\limits_{A = 1}^{\frac{R}{B}} (R-AB+1) \sum\limits_{i = 0}^{B} \binom{B+R+1}{i} A^{B-i}\)。
交换一下求和顺序就是 \(\sum\limits_{i = 0}^{B} \binom{H+R+1}{i} ( (R+1) \sum\limits_{A = 1}^{\frac{R}{B}} A^{B-i} - B \sum\limits_{A = 1}^{\frac{R}{B}} A^{B-i+1} )\)。可以伯努利数解决。
法 2
前置知识:广义二项级数
从这里开始推:
考虑如何计算后面的东西。
于是变成了和 法1 完全一样的形式了。
代码
#include<bits/stdc++.h>
#define L(i, j, k) for(int i = j, i##E = k; i <= i##E; i++)
#define R(i, j, k) for(int i = j, i##E = k; i >= i##E; i--)
#define ll long long
#define pii pair<int, int>
#define db double
#define x first
#define y second
#define ull unsigned long long
#define sz(a) ((int) (a).size())
#define vi vector<int>
using namespace std;
const int mod = 998244353, G = 3, iG = (mod + 1) / G, N = 2.1e6 + 7, inv2 = (mod + 1) / 2;
#define add(a, b) (a + b >= mod ? a + b - mod : a + b)
#define dec(a, b) (a < b ? a - b + mod : a - b)
inline ull calc(const ull &x) {
return x - (__uint128_t(x) * 9920937979283557439ull >> 93) * 998244353;
}
int qpow(int x, int y = mod - 2) {
int res = 1;
for(; y; x = (ll) x * x % mod, y >>= 1) if(y & 1) res = (ll) res * x % mod;
return res;
}
int n, m, fac[N], ifac[N], inv[N];
void init(int x) {
fac[0] = ifac[0] = inv[1] = 1;
L(i, 2, x) inv[i] = (ll) inv[mod % i] * (mod - mod / i) % mod;
L(i, 1, x) fac[i] = (ll) fac[i - 1] * i % mod, ifac[i] = (ll) ifac[i - 1] * inv[i] % mod;
}
int rt[N], Lim;
void Pinit(int x) {
for(Lim = 1; Lim <= x; Lim <<= 1) ;
int sG = qpow(G, (mod - 1) / Lim); rt[0] = 1;
L(i, 1, Lim) rt[i] = (ll) rt[i - 1] * sG % mod;
}
int C(int x, int y) {
return y < 0 || x < y ? 0 : (ll) fac[x] * ifac[y] % mod * ifac[x - y] % mod;
}
int rev[N];
void initrev(int n) {
L(i, 0, n - 1) rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) * (n >> 1)));
}
struct poly {
vector<int> a;
int size() { return sz(a); }
int & operator [] (int x) { return a[x]; }
int v(int x) { return x < 0 || x >= sz(a) ? 0 : a[x]; }
void clear() { vector<int> ().swap(a); }
void rs(int x = 0) { a.resize(x); }
poly (int n = 0) { rs(n); }
poly (vector<int> o) { a = o; }
poly (const poly &o) { a = o.a; }
poly Rs(int x = 0) { vi res = a; res.resize(x); return res; }
void ntt(int op, int t = true) {
int n = sz(a);
if(t) initrev(n);
L(i, 0, n - 1) if(rev[i] < i) swap(a[rev[i]], a[i]);
for(int i = 2; i <= n; i <<= 1)
for(int j = 0, l = (i >> 1), ch = Lim / i; j < n; j += i)
for(int k = j, now = 0; k < j + l; k++) {
int pa = a[k], pb = calc((ull) a[k + l] * (op == 1 ? rt[now] : rt[Lim - now]));
a[k] = add(pa, pb), a[k + l] = dec(pa, pb), now += ch;
}
if(op != 1) for(int i = 0, iv = qpow(n); i < n; i++) a[i] = (ll) a[i] * iv % mod;
}
friend poly operator * (poly aa, poly bb) {
if(!sz(aa) || !sz(bb)) return {};
int lim, all = sz(aa) + sz(bb) - 1;
for(lim = 1; lim < all; lim <<= 1);
initrev(lim), aa.rs(lim), bb.rs(lim), aa.ntt(1, false), bb.ntt(1, false);
L(i, 0, lim - 1) aa[i] = (ll) aa[i] * bb[i] % mod;
aa.ntt(-1, false), aa.a.resize(all);
return aa;
}
friend poly operator * (poly aa, int bb) {
poly res(sz(aa));
L(i, 0, sz(aa) - 1) res[i] = (ll) aa[i] * bb % mod;
return res;
}
friend poly operator + (poly aa, poly bb) {
vector<int> res(max(sz(aa), sz(bb)));
L(i, 0, sz(res) - 1) res[i] = add(aa.v(i), bb.v(i));
return poly(res);
}
friend poly operator - (poly aa, poly bb) {
vector<int> res(max(sz(aa), sz(bb)));
L(i, 0, sz(res) - 1) res[i] = dec(aa.v(i), bb.v(i));
return poly(res);
}
poly & operator += (poly o) {
rs(max(sz(a), sz(o)));
L(i, 0, sz(a) - 1) (a[i] += o.v(i)) %= mod;
return (*this);
}
poly & operator -= (poly o) {
rs(max(sz(a), sz(o)));
L(i, 0, sz(a) - 1) (a[i] += mod - o.v(i)) %= mod;
return (*this);
}
poly & operator *= (poly o) {
return (*this) = (*this) * o;
}
poly Inv() {
poly res, f, g;
res.rs(1), res[0] = qpow(a[0]);
for(int m = 1, pn; m < sz(a); m <<= 1) {
pn = m << 1, f = res, g.rs(pn), f.rs(pn), initrev(pn);
for(int i = 0; i < pn; i++) g[i] = (*this).v(i);
f.ntt(1, false), g.ntt(1, false);
for(int i = 0; i < pn; i++) g[i] = (ll) f[i] * g[i] % mod;
g.ntt(-1, false);
for(int i = 0; i < m; i++) g[i] = 0;
g.ntt(1, false);
for(int i = 0; i < pn; i++) g[i] = (ll) f[i] * g[i] % mod;
g.ntt(-1, false), res.rs(pn);
for(int i = m; i < min(pn, sz(a)); i++) res[i] = (mod - g[i]) % mod;
}
return res;
}
poly Integ() {
if(!sz(a)) return poly();
poly res(sz(a) + 1);
L(i, 1, sz(a)) res[i] = (ll) a[i - 1] * inv[i] % mod;
return res;
}
poly Deriv() {
if(!sz(a)) return poly();
poly res(sz(a) - 1);
L(i, 1, sz(a) - 1) res[i - 1] = (ll) a[i] * i % mod;
return res;
}
poly Ln() {
poly g = ((*this).Inv() * (*this).Deriv()).Integ();
return g.rs(sz(a)), g;
}
poly Exp() {
poly res(1), f;
res[0] = 1;
for(int m = 1, pn; m < sz(a); m <<= 1) {
pn = min(m << 1, sz(a)), f.rs(pn), res.rs(pn);
for(int i = 0; i < pn; i++) f[i] = (*this).v(i);
f -= res.Ln(), (f[0] += 1) %= mod, res *= f, res.rs(pn);
}
return res.rs(sz(a)), res;
}
poly pow(int x) {
poly res = (*this).Ln();
L(i, 0, sz(res) - 1) res[i] = (ll) res[i] * x % mod;
res = res.Exp();
return res;
}
poly sqrt() {
poly res(1), f;
res[0] = 1;
for(int m = 1, pn; m < sz(a); m <<= 1) {
pn = min(m << 1, sz(a)), f.rs(pn);
for(int i = 0; i < pn; i++) f[i] = (*this).v(i);
f += res * res, f.rs(pn), res.rs(pn), res = f * res.Inv(), res.rs(pn);
for(int i = 0; i < pn; i++) res[i] = (ll) res[i] * inv2 % mod;
}
return res;
}
void Rev() {
reverse(a.begin(), a.end());
}
} ;
poly Mul(poly aa, poly bb, int all = 0) {
if(!sz(aa) || !sz(bb)) return {};
if(!all) all = sz(aa) + sz(bb) - 1;
int lim; for(lim = 1; lim < all; lim <<= 1);
initrev(lim), aa.rs(lim), bb.rs(lim), aa.ntt(1, 0), bb.ntt(1, 0);
L(i, 0, lim - 1) aa[i] = calc((ull) aa[i] * bb[i]);
aa.ntt(-1, 0), aa.a.resize(all);
return aa;
}
int B, ns, now = 1;
ll R;
int main() {
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
cin >> R >> B, init(B + 2), Pinit(B * 2 + 4);
if(R < B) {
cout << "0\n";
return 0;
}
poly a(B + 2), b(B + 2);
L(i, 0, B + 1) a[i] = ifac[i + 1];
a = a.Inv(), now = 1;
L(i, 0, B + 1) now = (R / B + 1) % mod * now % mod, b[i] = (ll) now * ifac[i + 1] % mod;
a *= b;
L(i, 0, B + 1) a[i] = (ll) a[i] * fac[i] % mod;
(a[0] += mod - 1) %= mod;
now = 1;
L(i, 0, B)
(ns += ((R + 1) % mod * a[B - i] % mod + mod - (ll) B * a[B - i + 1]% mod) % mod * now % mod)
%= mod, now = (ll) (B + R + 1 - i) % mod * now % mod * inv[i + 1] % mod;
cout << ns << "\n";
return 0;
}