[组合数学][多项式][拉格朗日插值]count
- 源自 ditoly 大爷的 FJ 省队集训课件
Statement
-
有 \(m\) 个正整数变量,求有多少种取值方案
-
使得所有变量的和不超过 \(S\)
-
并且前 \(n\) 个变量的值都不超过 \(t\)
-
答案对 \(10^9+7\) 取模
-
\(m-n\le1000\) ,\(m\le 10^9\) ,\(t\le 10^5\) ,\(nt\le s\le 10^{18}\)
Solution
-
由于 \(nt\le s\le 10^{18}\) ,所以我们的解可以表示成:
-
\[\sum_{x_1=1}^t\sum_{x_2=1}^t\dots\sum_{x_n=1}^t\binom{S-\sum_{i=1}^nx_i}{m-n} \]
-
设 \(s=S-(m-n)+1-\sum_{i=1}^nx_i\)
-
注意到上面的和式里面的组合数是一个关于 \(s\) 的多项式,具体地,它等于 \(\frac{s^{\overline{m-n}}}{(m-n)!}\)
-
其中 \(n^{\overline m}\) 表示 \(n\) 的 \(m\) 次上升幂,即 \(\prod_{i=0}^{m-1}(n+i)\)
-
由第一类斯特林数的生成函数可得,这个多项式的 \(i\) 次项系数为 \(\frac{\begin{bmatrix}m-n\\i\end{bmatrix}}{(m-n)!}\) ,其中 \(\begin{bmatrix}n\\m\end{bmatrix}\) 为第一类斯特林数,即把 \(n\) 个元素分成 \(m\) 个圆排列的方案数
-
把组合数表示成多项式的形式之后,我们考虑如果我们求出了在前 \(n\) 个变量所有 \(t^n\) 种取值下对应的 \(s^i\) (\(0\le i\le m-n\))之和,那么问题就能很好地解决了
-
考虑 DP:\(f[i][j]\) 表示 \(x_{1\dots i}\) 所有取值下 \((S-m+n+1-\sum_{k=1}^ix_k)^j\) 的和
-
考虑如何从 \(f[i]\) 转移到 \(f[i+1]\)
-
这个转移即枚举 \(x_{i+1}\) 的取值,设 \(s=S-m+n+1-\sum_{k=1}^ix_k\)
-
可以发现:
-
\[\sum_{k=1}^t(s-k)^j=\sum_{k=0}^j(-1)^{j-k}\binom jks^ksum(j-k,t) \]
-
其中 \(sum(p,n)=\sum_{i=1}^ni^p\) ,可以利用插值 \(O(p)\) 求出
-
于是我们有了一个 DP 转移:
-
\[f[i+1][j]=\sum_{k=0}^j(-1)^{j-k}\binom jksum(j-k,t)f[i][k] \]
-
把组合数拆开:
-
\[\frac{f[i+1][j]}{j!}=\sum_{k=0}^j\frac{(-1)^{j-k}sum(j-k,t)}{(j-k)!}\frac{f[i][k]}{k!} \]
-
设多项式 \(F_i(x)=\sum_{j=0}^{m-n}f[i][j]x^j\) ,\(G(x)=\sum_{i=0}^{m-n}\frac{(-1)^isum(i,t)}{i!}x^i\)
-
那么很容易得到 \(F_n(x)=G(x)^n\) ,倍增快速幂即可
-
复杂度 \(O((m-n)^2\log n)\) 或 \(O((m-n)^2)\)
Code
#include <bits/stdc++.h>
const int N = 1010, rqy = 1e9 + 7;
typedef long long ll;
ll s;
int t, n, m, l, pw[N][N], f[N], fac[N], inv[N], invt[N], g[N], ans, S[N][N];
int qpow(int a, int b)
{
int res = 1;
while (b)
{
if (b & 1) res = 1ll * res * a % rqy;
a = 1ll * a * a % rqy;
b >>= 1;
}
return res;
}
int calc(int T)
{
int sum = 0, res = 0, al = 1;
for (int i = 1; i <= T + 2; i++) al = 1ll * al * (t - i + rqy) % rqy;
for (int i = 1; i <= T + 2; i++)
{
sum = (sum + pw[i][T]) % rqy;
if (i == t) return sum;
int delta = 1ll * al * invt[i] % rqy *
inv[i - 1] % rqy * inv[T + 2 - i] % rqy;
if (T + 2 - i & 1) delta = (rqy - delta) % rqy;
res = (1ll * delta * sum + res) % rqy;
}
return res;
}
int main()
{
std::cin >> s >> t >> n >> m;
l = m - n; s -= l - 1;
f[0] = fac[0] = inv[0] = inv[1] = 1;
for (int i = 1; i <= l + 2; i++) fac[i] = 1ll * fac[i - 1] * i % rqy;
for (int i = 2; i <= l + 2; i++)
inv[i] = 1ll * (rqy - rqy / i) * inv[rqy % i] % rqy;
for (int i = 2; i <= l + 2; i++) inv[i] = 1ll * inv[i] * inv[i - 1] % rqy;
for (int i = 1; i <= l + 2; i++)
{
pw[i][0] = 1;
for (int j = 1; j <= l; j++) pw[i][j] = 1ll * pw[i][j - 1] * i % rqy;
}
for (int i = 1; i <= l + 2; i++) invt[i] = qpow((t - i + rqy) % rqy, rqy - 2);
for (int i = 1; i <= l; i++) f[i] = s % rqy * f[i - 1] % rqy;
for (int i = 0; i <= l; i++) f[i] = 1ll * f[i] * inv[i] % rqy;
for (int i = 0; i <= l; i++)
{
g[i] = 1ll * calc(i) * inv[i] % rqy;
if (i & 1) g[i] = (rqy - g[i]) % rqy;
}
while (n)
{
if (n & 1) for (int i = l; i >= 0; i--)
{
f[i] = 1ll * f[i] * g[0] % rqy;
for (int j = 1; j <= i; j++)
f[i] = (1ll * f[i - j] * g[j] + f[i]) % rqy;
}
for (int i = l; i >= 0; i--)
{
g[i] = 1ll * g[i] * g[0] % rqy;
if (i) g[i] = (g[i] + g[i]) % rqy;
for (int j = 1; j < i; j++)
g[i] = (1ll * g[i - j] * g[j] + g[i]) % rqy;
}
n >>= 1;
}
S[0][0] = 1;
for (int i = 1; i <= l; i++)
for (int j = 1; j <= i; j++)
S[i][j] = (1ll * S[i - 1][j] * (i - 1) + S[i - 1][j - 1]) % rqy;
for (int i = 0; i <= l; i++)
ans = (1ll * f[i] * fac[i] % rqy * S[l][i] + ans) % rqy;
return std::cout << 1ll * ans * inv[l] % rqy << std::endl, 0;
}