[组合数学][CF1188E]Problem from Red Panda
题目链接
简要题意
有一个长度为 \(k\) 的数组 \(a\),每次可以选择一个 \(1\le i\le k\),让 \(a_i\) 加上 \(k-1\),并对于所有的 \(j\ne i\) 让 \(a_j\) 减掉 \(1\),任何时候必须保证 \(a\) 数组非负。
求通过任意多次(可以为 \(0\) 次)操作,能达到的不同的 \(a\) 数组方案数膜 \(998,244,353\) 后的结果。
数据范围:\(2\le k\le 10^5\),\(a_i\ge 0\),\(\sum_{i=1}^ka_i\le10^6\),但实际上存在复杂度与 \(\sum_{i=1}^ka_i\) 无关的做法。
Solution
设 \(x_i\) 表示 \(a_i\) 被选中的次数,考虑一个 \(\{x_i\}\) 合法的条件。记 \(s=\sum_{i=1}^kx_i\)。
我们不妨把操作看作数组 \(a\) 整体减 \(1\) 之后 \(a_i+=k\)。
显然我们必须保证最后的 \(a\) 数组非负,故 \(a_i-s+kx_i\ge 0\),也就是 \(x_i\ge\lceil\frac{\max(s-a_i,0)}k\rceil\)。
在这个条件下,判断是否对于 \(0\le t<s\) 满足 \(t\) 轮操作之后 \(a\) 数组非负,只需将所有 \(a_i\) 减掉 \(t\),然后尝试用 \(t\) 个 \(k\) 来填充为负的 \(a_i\) 值,判断是否能够填充成功即可,即 \(\sum_{i=1}^k\lceil\frac{\max(t-a_i,0)}k\rceil\le t\)。显然在 \(x_i\ge\lceil\frac{\max(s-a_i,0)}k\rceil\) 的限制下,第 \(i\) 个数被操作的次数不会超过 \(x_i\)。
于是一个 \(\{x_i\}\) 合法的条件为:
(1)\(x_i\ge\lceil\frac{\max(s-a_i,0)}k\rceil\)
(2)对于所有 \(0\le t\le s\) 都有 \(\sum_{i=1}^k\lceil\frac{\max(t-a_i,0)}k\rceil\le t\)。
在从小到大枚举 \(s\) 的过程中,\(\sum_{i=1}^k\lceil\frac{\max(s-a_i,0)}k\rceil\) 容易求出,将 \(a\) 排序后用指针维护 \(a_i<s\) 的部分,对 \(a_i\bmod k\) 用个桶维护每种值的出现次数即可。
回到问题,我们不能直接对 \(\{x_i\}\) 计数,因为不同的 \(x\) 数组可能对应同一个 \(a\) 数组。
首先我们发现,如果所有的 \(x_i\) 都相等,则这样的操作对原数组没有影响。
也就是说,如果所有的 \(x_i\) 都不为 \(0\),则把所有 \(x_i\) 都减掉 \(1\) 之后会得到一个等价的方案。
同样地如果将一部分 \(x_i\)(个数在 \([1,k-1]\) 之间)减掉 \(1\),则得到的方案一定不等价。
故可以转化成对 \(\{x_i\}\) 数组计数,但 \(x\) 数组必须满足至少有一个 \(0\)。易得这时有 \(0\le s\le\max a_i\)。
从小到大枚举 \(s\),遇到 \(w=\sum_{i=1}^k\lceil\frac{\max(s-a_i,0)}k\rceil>s\) 的情况立刻 break
掉。
问题转化成 \(k\) 个变量,其中前 \(r\)(\(a_i<s\) 的 \(i\) 个数)个变量有一个取值下界 \(down_i\),满足 \(down_i\ge 1\) 且 \(w=\sum_{i=1}^rdown_i\),求为这 \(k\) 个变量取值,使得至少有一个 \(0\),并且所有变量的和为 \(s\) 的方案数。
先去掉下界 \(down\),转成所有变量的和为 \(s-w\),并且后 \(k-r\) 个变量至少有一个 \(0\)。
考虑容斥,用任意方案减掉没有 \(0\) 的方案,由插板法得方案数:
总复杂度 \(O(k\log k+\max a_i)\)。
Code
#include <bits/stdc++.h>
template <class T>
inline void read(T &res)
{
res = 0; bool bo = 0; char c;
while (((c = getchar()) < '0' || c > '9') && c != '-');
if (c == '-') bo = 1; else res = c - 48;
while ((c = getchar()) >= '0' && c <= '9')
res = (res << 3) + (res << 1) + (c - 48);
if (bo) res = ~res + 1;
}
const int N = 2e6 + 5, djq = 998244353;
int k, n, a[N], cnt[N], fac[N], inv[N], ans;
int C(int n, int m) {return 1ll * fac[n] * inv[m] % djq * inv[n - m] % djq;}
int main()
{
fac[0] = inv[0] = inv[1] = 1;
for (int i = 1; i < N; i++) fac[i] = 1ll * fac[i - 1] * i % djq;
for (int i = 2; i < N; i++) inv[i] = 1ll * (djq - djq / i) * inv[djq % i] % djq;
for (int i = 2; i < N; i++) inv[i] = 1ll * inv[i] * inv[i - 1] % djq;
read(k); int cur = 0;
for (int i = 1; i <= k; i++) read(a[i]), n += a[i];
std::sort(a + 1, a + k + 1);
for (int i = 0, j = 1; i <= a[k]; i++)
{
while (a[j] < i) cnt[a[j++] % k]++;
cur += cnt[(i - 1 + k) % k];
if (cur > i) return std::cout << ans << std::endl, 0;
ans = (ans + C(i - cur + k - 1, k - 1)) % djq;
if (i - cur + j - 2 >= k - 1)
ans = (ans - C(i - cur + j - 2, k - 1) + djq) % djq;
}
return std::cout << ans << std::endl, 0;
}