2022牛客国庆集训派对day6 J(组合数+DP)
2022牛客国庆集训派对day6 J(组合数+DP)
J-Just Jump_2022牛客国庆集训派对day6 (nowcoder.com)
题目
要从0走到L,每步至少走 \(d\) ,且有 \(m\) 个攻击,当攻击到达时我们不能站在攻击点上。
求有多少可行方案能从0到L而不被打到。
思路
先求出总方案数,再dp求出非法方案并减去。
求总方案很简单,不多赘述。
f[0] = s[0] = 1;
for(int i = 1;i <= L;i ++) {
if(i - d >= 0) f[i] = s[i - d];
s[i] = (1ll * s[i - 1] + f[i]) % mod;
}
如何求非法方案,如果不考虑至少走 \(d\) 的限制,那么到位置 \(p\) 恰好走 \(t\) 步可以转化为 “有n个球,分k堆,允许出现空堆”这样的插隔板法问题。
进一步的,加入走 \(d\) 步的限制,就是先在每堆铺 \(d\) 个小球的隔板法。
此时我们算出了“从0到 \(p\) 跳 \(t\) 步的方案数”。当攻击次数大于1时,这些非法方案之间会出现重叠的情况。此时可以 按第一次被第 \(i\) 个攻击打到的方案数 对非法方案做 不重不漏的划分 并设为 \(g[i]\)。我们按 \(p_i\) 升序处理这些攻击。当我们求 \(g[i]\) 时,先求出从0到 \(p\) 跳 \(t\) 步的方案数,再依次减去第一次被 \(j|j<i\) 打到的方案数产生的贡献就能算出 \(g[i]\)。
\(g[j]\) 的单点贡献是:从0到 \(j\) 的方案数 \(g[j]\) * 从 \(j\) 到 \(i\) 走 \(t[i]-t[j]\) 的方案数。对于后者,又是一个类似的插隔板。
求出 \(g[i]\) 后,用总方案减去非法方案。对第一次被第 \(i\) 攻击打到的贡献是:g[i] * 从 \(p_i\) 到 \(L\) 的方案数 \(\to\) \(g[i]*f[L - p_i]\) 。
有一个坑点,计算组合数时不能这样写:
int C(int n,int m) {
if(n < m or n < 0 or m < 0) return 0;
return 1ll * fac[n] * invfac[m] % mod * invfac[n - m] % mod;
}
传入的组合数参数 \(n,m\) 可能爆 \(int\) 。形参改成 long long
#include <bits/stdc++.h>
#define ll long long
using namespace std;
using PII = pair<int,int>;
const int N = 1e7 + 5, M = 3010, mod = 998244353;
int f[N],g[M],s[N],L,m,d;
PII atk[M];
int fac[N],invfac[N];
int ksm(int a,int b) {
int ans = 1;
while(b) {
if(b & 1) ans = 1ll * ans * a % mod;
a = 1ll * a * a % mod;
b >>= 1;
}
return ans;
}
void init() {
fac[0] = invfac[0] = 1;
for(int i = 1;i < N;i ++) fac[i] = 1ll * fac[i - 1] * i % mod;
invfac[N - 1] = ksm(fac[N - 1], mod - 2);
for(int i = N - 2;i > 0;i --) invfac[i] = 1ll * invfac[i + 1] * (i + 1) % mod;
}
int C(ll n,ll m) {
if(n < m or n < 0 or m < 0) return 0;
return 1ll * fac[n] * invfac[m] % mod * invfac[n - m] % mod;
}
int main() {
scanf("%d%d%d", &L, &d, &m);
for(int i = 1;i <= m;i ++) scanf("%d%d", &atk[i].first, &atk[i].second);
f[0] = s[0] = 1;
for(int i = 1;i <= L;i ++) {
if(i - d >= 0) f[i] = s[i - d];
s[i] = (1ll * s[i - 1] + f[i]) % mod;
}
sort(atk + 1, atk + 1 + m);
int ans = f[L];
for(int i = 1;i <= m;i ++) {
int ti = atk[i].first, pi = atk[i].second;
g[i] = C(pi - 1ll * d * ti + ti - 1, ti - 1);
for(int j = 1;j < i;j ++) {
int dt = ti - atk[j].first, dp = pi - atk[j].second;
g[i] = ((g[i] - 1ll * g[j] * C(dp - 1ll * d * dt + dt - 1, dt - 1) % mod) % mod + mod) % mod;
}
ans = ((ans - 1ll * g[i] * f[L - pi] % mod) % mod + mod) % mod;
}
printf("%d", ans);
return 0;
}