AGC021E Ball Eat Chameleons
AGC021E Ball Eat Chameleons
题目大意: 有 \(n\) 条变色龙,一开始都为蓝色,现在要求合法的长度为 \(k\) 的 RB 序列数量。对于一个序列,会按照序列顺序丢对应颜色的球,每一个球有随机一只变色龙吃下。如果这条变色龙吃的当前颜色球数量大于另一种颜色球数量,就会变色。求有多少序列,使得最终所有变色龙有可能变为红色。
数据范围:\(1\leq n,k\leq 5 \times 10^5\) 。
解题思路:考虑给定一个序列,怎样做才尽可能让所有变色龙变成红色。注意到红球和蓝球可以相互抵消,而让一个变色龙变色需要额外一个球的代价。那么对于一个红色的变色龙,如果没有吃过蓝球,可以免费吃一个,如果没有这样的变色龙,所有的蓝球都会丢到同一个变色龙上,因为丢在多个会付出多份代价。所以可以得到一个正确的策略,首先空出1号变色龙,对于一个红球,如果还有未丢过球的变色龙,就让它变色,否则丢给1号变色龙。而对于一个蓝球,如果有可以免费抵消的红色变色龙,则丢给它,否则丢给1号变色龙。
设红球有 \(R\) 个,蓝球有 \(B\) 个,这样子的策略1号变色龙吃到的红球数为 \(t=R-(n-1)\) ,此时如果 \(R < B\) ,会发现一号变色龙吃到的蓝球数量最少为 \(B-(n-1)\) 无解。
考虑当红球放了 \(n-1\) 个之前,1号变色龙吃到蓝球当且仅当没有红球可以免费抵消了,也就是蓝球匹配前面的红球无法匹配,记蓝球权值为 \(1\),红球权值为 \(-1\),那么 1号变色龙在红球放了 \(n-1\) 个之前吃到的蓝球数量为权值前缀和的 \(\max\) 。也就是说在这之前权值前缀和的 \(\max\) 必须 \(<t\) 。
考虑红球放了 \(n-1\) 个之后的情况,如果 \(R=B\) ,设之前吃到的蓝球数为 \(s\) ,显然有 \(s < t =B -(n-1)\) ,也就是剩下的蓝球数大于红球数,那么1号变色龙吃到的蓝球数一定为 \(t\) 个。由于要保证变色,最后一个球只能是蓝球,所以方案数就是任意时刻前缀 \(\max < t\) 且最后一个球是蓝球的方案数,即 \({k-1\choose R}-{k-1\choose R+t}\) 。
如果 \(R>B\) ,如果剩下的蓝球数能完全被前 \((n-1)\) 个红球抵消,那么显然是一个合法的序列,而这个序列是一定满足 \(\max < t\) 的,否则说明这 \((n-1)\) 个红球中的每一个都抵消了一个蓝球,而 \(R-(n-1)>B-(n-1)\) 一定是一个合法的序列,此时倒推可知满足 \(\max < t\) ,所以在这种情况下 \(\max < t\) 是合法的充要条件,方案数就是 \({k\choose R}-{k\choose R+t}\)。
方案数推导,方案数可以看做从 \((0,0)\) 走到 \((N,M)\) 且不能碰到直线 \(y-x=L\) 的方案数,那么对于一个不合法的方案,找到第一次碰到直线 \(y-x=L\) 的位置,并将之前的部分沿直线翻折,会等价于一个从 \((-L,L)\) 走到 \((N,M)\) 的方案数,所以就是两个组合数相减 \({N+M\choose N}-{N+M\choose N+L}\) 。
code
/*program by mangoyang*/
#pragma GCC optimize("Ofast", "inline")
#include<bits/stdc++.h>
#define inf (0x3f3f3f3f)
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
typedef long long ll;
using namespace std;
template <class T>
inline void read(T &x){
int ch = 0, f = 0; x = 0;
for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1;
for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
if(f) x = -x;
}
const int N = 1000005, mod = 998244353;
int js[N], inv[N], n, k;
inline int Pow(int a, int b){
int ans = 1;
for(; b; b >>= 1, a = 1ll * a * a % mod)
if(b & 1) ans = 1ll * ans * a % mod;
return ans;
}
inline int C(int x, int y){
if(x < 0 || y < 0 || x < y) return 0;
return 1ll * js[x] * inv[y] % mod * inv[x-y] % mod;
}
inline void up(int &x, int y){
x = x + y >= mod ? x + y - mod : x + y;
}
int main(){
js[0] = 1, inv[0] = 1;
for(int i = 1; i < N; i++){
js[i] = 1ll * js[i-1] * i % mod;
inv[i] = Pow(js[i], mod - 2);
}
read(n), read(k);
int ans = 0;
for(int i = n; i <= k; i++)
if(i > k - i){
int t = i - n + 1;
up(ans, C(k, i));
up(ans, mod - C(k, i + t));
}
else if(i == k - i){
int t = i - n + 1;
up(ans, C(k - 1, i));
up(ans, mod - C(k - 1 , i + t));
}
cout << ans << endl;
return 0;
}