CS Academy Sliding Product Sum(组合数)
题意
有一个长为 \(N\) 的序列 \(A = [1, 2, 3, \dots, N]\) ,求所有长度 \(\le K\) 的子串权值积的和,对于 \(M\) 取模。
\(N \le 10^{18}, K \le \min(600, n), M \le 10^{18}\)
题解
一道还有些意思的组合数学题 qwq
一开始觉得这不是 \(K\) 次多项式么,直接插值QAQ 发现模数不给,逆元可能都没有,太不友好啦。
令 \(ans_k\) 为长度为 \(k\) 的子串的贡献和。其实我们就是求对于所有 \(k \le K\) 的 \(ans_k\) 的和。
先推推式子。
那么我们最后就是求对于所有 \(k \le K\) 的 \(k!\) 和 \(\displaystyle {n + 1 \choose k +1}\) 。
前者很好求,对于后者么。。。组合数,\(n\) 好大。。\(Lucas\) ?\(m\) 也好大。。。弃疗
但我们会发现 \(k\) 其实不是很大QAQ
我们需要知道有这样一个东西
为什么呢?思考一下组合意义就很明显啦。
当 \(n = m\) 的时候就有
有了这个就很好做啦~
我们维护一个序列 \(A_i\) 为 \(\displaystyle [{i \choose 0}, {i \choose 1}, \dots, {i \choose K + 1}]\) 。最后我们要求的就是 \(A_{N + 1}\) 。
那么有前面那个式子,我们就可以倍增求出 \(A_n\) 啦。
所以复杂度是 \(O(k^2 \log n)\) 的。(默认不适用慢速乘,用 __int128
)
前面那个是卷积的形式,也可以用 \(FFT\) 优化到 \(O(k \log k \log n)\) ,但由于模数很鬼畜,似乎没有那么优秀。
代码
#include <bits/stdc++.h>
#define For(i, l, r) for (register int i = (l), i##end = (int)(r); i <= i##end; ++i)
#define Fordown(i, r, l) for (register int i = (r), i##end = (int)(l); i >= i##end; --i)
#define Rep(i, r) for (register int i = (0), i##end = (int)(r); i < i##end; ++i)
#define Set(a, v) memset(a, v, sizeof(a))
#define Cpy(a, b) memcpy(a, b, sizeof(a))
#define debug(x) cout << #x << ": " << (x) << endl
using namespace std;
using ll = __int128;
template<typename T> inline bool chkmin(T &a, T b) { return b < a ? a = b, 1 : 0; }
template<typename T> inline bool chkmax(T &a, T b) { return b > a ? a = b, 1 : 0; }
template<typename T>
inline T read() {
T x(0), sgn(1); char ch(getchar());
for (; !isdigit(ch); ch = getchar()) if (ch == '-') sgn = -1;
for (; isdigit(ch); ch = getchar()) x = (x * 10) + (ch ^ 48);
return x * sgn;
}
void File() {
#ifdef zjp_shadow
freopen ("sliding-product-sum.in", "r", stdin);
freopen ("sliding-product-sum.out", "w", stdout);
#endif
}
const int N = 610;
ll n, Mod; int k;
struct Array {
ll a[N];
Array() { Set(a, 0); }
inline Array friend operator * (const Array &lhs, const Array &rhs) {
Array res;
For (i, 0, k) For (j, 0, k - i)
res.a[i + j] = (res.a[i + j] + lhs.a[i] * rhs.a[j]) % Mod;
return res;
}
};
ll fac[N];
Array fpm(Array x, ll power) {
Array res = x; -- power;
for (; power; power >>= 1, x = x * x)
if (power & 1) res = res * x;
return res;
}
int main() {
File();
n = read<ll>() + 1;
k = read<int>() + 1;
Mod = read<ll>();
Array base; base.a[0] = base.a[1] = 1;
Array prod = fpm(base, n);
fac[0] = 1;
For (i, 1, k) fac[i] = fac[i - 1] * i % Mod;
ll ans = 0;
For (i, 2, k)
ans = (ans + fac[i - 1] * prod.a[i]) % Mod;
printf ("%lld\n", (long long)(ans));
return 0;
}