[洛谷P5075][JSOI2012]分零食
题目大意:有$m(m\leqslant10^8)$个人站成一排,有$n(n\leqslant10^4)$个糖果,若第$i$个人没有糖果,那么第$i+1$个人也没有糖果。一个人有$x$个糖果会获得快乐值$v(x)$。
$$
v(x)=
\begin{cases}
ax^2+bx+c&(x\geqslant1)\\
1&(x=0)
\end{cases}
$$
一个方案的价值为$\prod\limits_{i=1}^mv(s_i)$($s_i$为第$i$个人得到的糖果数)。问所有方案的价值和,对$mod(mod\leqslant255)$取模
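题解:令$f(x)=\sum\limits_{i=1}^{\infty}v(i)x^i$。由于有糖果的人一定是一段前缀,恰好前$k$个人得到糖果(每人至少一个)的所有方案的价值和就是$[x^n]f^k(x)$。又因为$f(x)$没有常数项,当$n\geqslant1$时$[x^n]f^0(x)=0$,所以下面可以把求和下界从$1$改成$0$;同理$k>n$时$[x^n]f^k(x)=0$,因此可以先令$m\gets\min(m,n)$。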
$$
\begin{align*}
ans&=[x^n]\sum\limits_{i=1}^mf^i(x)\\
&=[x^n]\sum\limits_{i=0}^mf^i(x)\\
&=[x^n]\dfrac{1-f^{m+1}(x)}{1-f(x)}
\end{align*}
$$
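注意这里的模数不是质数,不能直接在模$mod$意义下做$NTT$;但模数很小,在题目范围内卷积系数的真实值不超过约$10^4\times255^2\approx6.5\times10^8<998244353$,所以可以只用一个模数$998244353$做$NTT$求出精确卷积,再对$mod$取模。求逆部分要注意:牛顿迭代$B\gets B(2-AB)$中的$-AB$不能直接在点值上计算,因为它的系数是负数,无法用非负整数精确表示;需要多做一次逆变换,把$AB$的点值转成系数,对$mod$取模后取相反数(加上$mod$保持非负),再做第二次乘法。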
卡点:$NTT$预处理部分度数没有加,调了一个上午。。。
C++ Code:
#include <cstdio>
#include <cstring>
#include <algorithm>

// Capacity of every polynomial buffer. With n <= 10^4 + 1 coefficients the
// longest transform is the smallest power of two >= 2n - 1, i.e. 32768.
const int maxn = 32768;
// NTT prime (3 is a primitive root). The answer's modulus `pmod` is small and
// not prime, so every product is computed exactly modulo this prime and only
// afterwards reduced modulo `pmod`. Assuming pmod <= 255 and n <= 10^4 + 1,
// each exact convolution coefficient is below about 10^4 * 256 * 254 < mod.
const int mod = 998244353;

namespace Math {
    // Returns base^p modulo `mod` (p >= 0) by binary exponentiation.
    inline int pw(int base, int p) {
        int res = 1;  // plain local: keeps the function reentrant
        for (; p; p >>= 1, base = static_cast<long long>(base) * base % mod)
            if (p & 1) res = static_cast<long long>(res) * base % mod;
        return res;
    }
    // Inverse of x modulo the prime `mod` (Fermat); x must be nonzero mod `mod`.
    inline int inv(int x) { return pw(x, mod - 2); }
}

// Maps x from (-mod, mod) into [0, mod). Uses a comparison instead of
// `x >> 31`, whose result on negative values is implementation-defined
// before C++20.
inline void reduce(int &x) {
    if (x < 0) x += mod;
}

// Zeroes the range [l, r); does nothing when l >= r.
inline void clear(int *l, const int *r) {
    for (; l < r; ++l) *l = 0;
}

int n, m, a, b, c, pmod;

namespace Poly {
    const int N = maxn;
    // lim: current transform length (a power of two); s = log2(lim) - 1;
    // rev: bit-reversal permutation for lim; Wn[k] = t^k for k in [0, lim],
    // where t = 3^((mod - 1) / lim) is a primitive lim-th root of unity.
    int lim, s, rev[N];
    int Wn[N + 1];

    // Prepares a transform of length lim = smallest power of two >= n.
    inline void init(const int n) {
        lim = 1, s = -1;
        while (lim < n) lim <<= 1, ++s;
        for (int i = 1; i < lim; ++i) rev[i] = rev[i >> 1] >> 1 | (i & 1) << s;
        const int t = Math::pw(3, (mod - 1) / lim);
        Wn[0] = 1;
        for (int i = 0; i < lim; ++i) Wn[i + 1] = static_cast<long long>(Wn[i]) * t % mod;
    }

    // In-place NTT of A[0, lim). Forward when op != 0; inverse (including the
    // 1/lim scaling) when op == 0. Entries must lie in [0, mod).
    inline void NTT(int *A, const int op = 1) {
        for (int i = 1; i < lim; ++i)
            if (i < rev[i]) std::swap(A[i], A[rev[i]]);
        for (int mid = 1; mid < lim; mid <<= 1) {
            const int t = lim / mid >> 1;  // stride into Wn for this layer
            for (int i = 0; i < lim; i += mid << 1)
                for (int j = 0; j < mid; ++j) {
                    const int X = A[i + j];
                    const int Y = static_cast<long long>(A[i + j + mid]) * Wn[j * t] % mod;
                    A[i + j] = X + (Y - mod);
                    reduce(A[i + j]);
                    A[i + j + mid] = X - Y;
                    reduce(A[i + j + mid]);
                }
        }
        if (!op) {
            // Scaling by 1/lim and reversing A[1, lim) is equivalent to
            // transforming with the inverse root.
            const int ilim = Math::inv(lim);
            for (int i = 0; i < lim; ++i) A[i] = static_cast<long long>(A[i]) * ilim % mod;
            std::reverse(A + 1, A + lim);
        }
    }

    // B = A^{-1} mod (x^n, pmod) via Newton iteration B <- B * (2 - A * B).
    // The base case sets B[0] = 1, so A[0] must be congruent to 1 mod pmod.
    // B is assumed zero-filled on entry: the base case only writes B[0] while
    // the transforms read B up to the transform length. On return B[0, n)
    // holds the inverse with entries in [0, pmod) and B[n, lim) is zero.
    inline void INV(int *A, int *B, int n) {
        if (n == 1) {
            B[0] = 1;
            return;
        }
        static int C[N];  // scratch; only touched after the recursive call returns
        INV(A, B, (n + 1) >> 1);
        init(n + n - 1);
        std::copy(A, A + n, C);
        clear(C + n, C + lim);
        NTT(C), NTT(B);
        for (int i = 0; i < lim; ++i) C[i] = static_cast<long long>(C[i]) * B[i] % mod;
        NTT(C, 0);
        clear(C + n, C + lim);
        // Negate in coefficient space (the exact integers of -A*B would be
        // negative): C[i] becomes pmod - (A*B)[i] % pmod, in [1, pmod], then
        // adding 2 to C[0] yields 2 - A*B modulo pmod.
        for (int i = 0; i < n; ++i) C[i] = pmod - C[i] % pmod;
        C[0] += 2;
        NTT(C);
        for (int i = 0; i < lim; ++i) B[i] = static_cast<long long>(B[i]) * C[i] % mod;
        NTT(B, 0);
        for (int i = 0; i < n; ++i) B[i] %= pmod;
        clear(B + n, B + lim);
    }

    // B = A^p mod (x^n, pmod) by binary exponentiation. Only A[0, n) is read;
    // its entries are assumed already reduced into [0, pmod). On return
    // B[0, n) holds the result and B[n, lim) is zero.
    inline void PW(int *A, int *B, int n, int p) {
        static int C[N], D[N];  // C: running square of A; D: scratch copy of C
        std::copy(A, A + n, C);
        init(n + n - 1);
        clear(C + n, C + lim);  // do not rely on the static buffer being fresh
        B[0] = 1;
        clear(B + 1, B + lim);
        while (p) {
            if (p & 1) {  // B *= C
                std::copy(C, C + n, D);
                clear(D + n, D + lim);
                NTT(B), NTT(D);
                for (int i = 0; i < lim; ++i) B[i] = static_cast<long long>(B[i]) * D[i] % mod;
                NTT(B, 0);
                clear(B + n, B + lim);
                for (int i = 0; i < n; ++i) B[i] %= pmod;
            }
            if (p >>= 1) {  // C *= C, skipped once no bits remain
                NTT(C);
                for (int i = 0; i < lim; ++i) C[i] = static_cast<long long>(C[i]) * C[i] % mod;
                NTT(C, 0);
                clear(C + n, C + lim);
                for (int i = 0; i < n; ++i) C[i] %= pmod;
            }
        }
    }
}

int f[maxn], A[maxn], B[maxn];

// Reads "n pmod" then "m a b c" and prints
// [x^n] (1 - f^(m+1)) / (1 - f) mod pmod, where f(x) = sum_{i>=1} (a i^2 + b i + c) x^i.
int main() {
    scanf("%d%d", &n, &pmod);
    ++n;  // from here on n counts the coefficients x^0 .. x^n
    scanf("%d%d%d%d", &m, &a, &b, &c);
    // f has no constant term, so f^k has no x^n term once k exceeds the original n.
    m = std::min(m, n - 1);
    for (int i = 1; i < n; ++i) {
        const long long x = i;  // 64-bit so the a- and b-terms cannot overflow int
        f[i] = static_cast<int>((x * x % pmod * a + x * b + c) % pmod);
    }
    Poly::PW(f, A, n, m + 1);  // A = f^(m+1)
    // A = 1 - f^(m+1), each coefficient shifted by pmod to stay nonnegative.
    for (int i = 0; i < n; ++i) A[i] = pmod - A[i];
    ++A[0];
    // f = 1 - f, same shifting; f[0] becomes pmod + 1, i.e. 1 mod pmod.
    for (int i = 0; i < n; ++i) f[i] = pmod - f[i];
    ++f[0];
    Poly::INV(f, B, n);  // B = 1 / (1 - f)
    Poly::init(n + n - 1);
    Poly::NTT(A), Poly::NTT(B);
    // Pointwise product over exactly the lim transform slots [0, lim).
    for (int i = 0; i < Poly::lim; ++i) A[i] = static_cast<long long>(A[i]) * B[i] % mod;
    Poly::NTT(A, 0);
    printf("%d\n", A[n - 1] % pmod);
    return 0;
}