[洛谷P3321][SDOI2015]序列统计
题目大意:给定$n,m,x$和一个集合$S$($0\leqslant S_i<m$,$m\leqslant 8000$且$m$为质数,$1\leqslant x<m$,$n\leqslant10^9$),求长度为$n$、满足$Q_i\in S$且$\prod\limits_{i=1}^nQ_i\equiv x\pmod m$的序列$Q$的个数,答案对$1004535809$取模。
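题解:直接处理乘法比较麻烦,可以先求出$m$的一个原根$g$(暴力$O(m^2)$即可),再对每个非零数取离散对数$\operatorname{ind}_g$(相当于以$g$为底取$\ln$)。这样$\prod Q_i\equiv x\pmod m$就等价于$\sum\operatorname{ind}_g Q_i\equiv\operatorname{ind}_g x\pmod{m-1}$,乘法变成了加法。令生成函数$F(z)=\sum_{s\in S,\,s\neq0}z^{\operatorname{ind}_g s}$,答案就是$F^n(z)$在模$z^{m-1}-1$意义下(即长度为$m-1$的循环卷积)$z^{\operatorname{ind}_g x}$项的系数。由于$x\neq0$,含$0$的序列乘积必为$0$,所以$S$中的$0$可以直接忽略。$n$较大,用多项式快速幂即可:每次乘法用NTT计算,并把次数$\geqslant m-1$的项折回到低次项。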
卡点:无
C++ Code:
#include <cstdio>
#include <algorithm>
#include <cstring>

// Array bounds as typed constants. They used to be macros, and
// `#define maxn 16384 | 3` carried no parentheses, so `Wn[maxn + 1]` expanded
// to `Wn[16384 | (3 + 1)]` and was only the right size by coincidence.
const int maxn = 16384 | 3;  // NTT length for m <= 8000 is 16384; plus slack
const int maxm = 8010;       // moduli m < maxm can be indexed by vis[]

// NTT-friendly prime (the answer is reported modulo it) and its primitive root.
const int mod = 1004535809, G = 3;
int n, m, x, S, g;
// Discrete-log table filled by get(): vis[g^j mod m] = j; -1 means "no entry".
int vis[maxm];

/**
 * Finds a primitive root of p by brute force and leaves the matching
 * discrete-log table in vis[]: vis[g^j mod p] = j for 0 <= j <= p - 2.
 *
 * @param p  modulus, assumed prime with 2 <= p < maxm
 * @return   the smallest primitive root of p, or -1 if none was found
 *           (cannot happen when p is prime)
 */
int get(int p) {
  if (p == 2) {
    // The group {1} is generated by 1. The candidate loop below is empty for
    // p == 2, so fill the table explicitly instead of relying on vis[] having
    // been zero-initialised.
    memset(vis, -1, sizeof vis);
    vis[1] = 0;
    return 1;
  }
  for (int i = 2; i < p; i++) {
    memset(vis, -1, sizeof vis);
    vis[1] = 0;
    bool found = true;
    // i is a primitive root iff i^1, ..., i^(p-2) are pairwise distinct and
    // never return to 1 (i.e. never hit an already-filled slot).
    for (int j = 1, t = 1; j < p - 1; j++) {
      t = t * i % p;
      if (vis[t] != -1) {
        found = false;
        break;
      }
      vis[t] = j;
    }
    if (found) return i;
  }
  return -1;
}

// NTT state set up by init(): transform length, its inverse modulo `mod`,
// the bit-reversal shift (log2(lim) - 1), and the bit-reversal permutation.
int lim, ilim, s, rev[maxn];
// base: 0/1 polynomial over the discrete logs of S; ans: running power of base;
// Wn[i] = w^i for a primitive lim-th root of unity w (so Wn[lim] == 1).
int base[maxn], ans[maxn], Wn[maxn + 1];

/** Returns base^p modulo `mod` by binary exponentiation (p >= 0). */
inline int pw(int base, int p) {
  int res = 1;
  for (; p; p >>= 1, base = 1ll * base * base % mod)
    if (p & 1) res = 1ll * res * base % mod;
  return res;
}

/** Modular inverse of x via Fermat's little theorem (`mod` is prime). */
inline int Inv(int x) { return pw(x, mod - 2); }

/**
 * Prepares an NTT whose length lim is the smallest power of two >= n:
 * fills ilim, s, rev[0, lim) and Wn[0, lim].
 * Precondition: 2 <= n and the resulting lim <= 16384 (array capacity).
 */
inline void init(int n) {
  lim = 1, s = -1;
  while (lim < n) lim <<= 1, s++;
  ilim = Inv(lim);
  for (int i = 0; i < lim; i++) rev[i] = rev[i >> 1] >> 1 | (i & 1) << s;
  const int t = pw(G, (mod - 1) / lim);  // primitive lim-th root of unity
  Wn[0] = 1;
  // Wn[lim] is needed too: the inverse transform reads Wn[lim - 0].
  for (int i = 1; i <= lim; i++) Wn[i] = 1ll * Wn[i - 1] * t % mod;
}

/** a = (a + b) mod `mod`, valid for a in [0, mod) and b in [0, mod]. */
inline void up(int &a, int b) {
  if ((a += b) >= mod) a -= mod;
}

/**
 * In-place iterative NTT of A[0, lim).
 * op != 0: forward transform; op == 0: inverse transform, including the
 * final division by lim.
 */
inline void NTT(int *A, int op) {
  for (int i = 0; i < lim; i++)
    if (i < rev[i]) std::swap(A[i], A[rev[i]]);
  for (int mid = 1; mid < lim; mid <<= 1) {
    const int t = lim / mid >> 1;  // stride through Wn at this level
    for (int i = 0; i < lim; i += mid << 1) {
      for (int j = 0; j < mid; j++) {
        // Inverse transform uses the conjugate root w^(-j*t) = Wn[lim - j*t].
        const int W = op ? Wn[j * t] : Wn[lim - j * t];
        const int X = A[i + j];
        const int Y = 1ll * A[i + j + mid] * W % mod;
        up(A[i + j], Y);               // X + Y
        A[i + j + mid] = X;
        up(A[i + j + mid], mod - Y);   // X - Y
      }
    }
  }
  if (!op)
    for (int i = 0; i < lim; i++) A[i] = 1ll * A[i] * ilim % mod;
}

// Scratch buffers for MUL.
int C[maxn], D[maxn];

/**
 * A = A * B as a cyclic convolution of length m - 1: indices are exponents of
 * the primitive root, which live modulo m - 1. A and B may be the same array.
 * Requires A and B to be zero at indices >= m - 1, so the linear product
 * (degree <= 2m - 4) fits in lim >= 2m without wrap-around.
 */
inline void MUL(int *A, int *B) {
  std::copy(A, A + lim, C);
  NTT(C, 1);
  if (A == B) {
    // Squaring: one forward transform is enough.
    for (int i = 0; i < lim; i++) C[i] = 1ll * C[i] * C[i] % mod;
  } else {
    std::copy(B, B + lim, D);
    NTT(D, 1);
    for (int i = 0; i < lim; i++) C[i] = 1ll * C[i] * D[i] % mod;
  }
  NTT(C, 0);
  std::copy(C, C + lim, A);
  // Fold exponents >= m - 1 back, since g^(m-1) == 1 (mod m).
  for (int i = m - 1; i < lim; i++) up(A[i % (m - 1)], A[i]), A[i] = 0;
}

/**
 * Reads n, m, x, |S| and the elements of S. Prints the number of length-n
 * sequences over S whose product is congruent to x mod m, modulo `mod`.
 * Assumes 1 <= x < m and 0 <= S_i < m, so vis[] is indexed in range.
 */
int main() {
  if (scanf("%d%d%d%d", &n, &m, &x, &S) != 4) return 1;
  g = get(m);
  for (int i = 0, tmp; i < S; i++) {
    if (scanf("%d", &tmp) != 1) return 1;
    // 0 has no discrete log, and any sequence containing it has product 0,
    // which can never equal x (x >= 1), so it is simply ignored.
    if (tmp) base[vis[tmp]] = 1;
  }
  init(m << 1);
  ans[0] = 1;  // empty product: g^0 = 1
  for (; n; n >>= 1) {
    if (n & 1) MUL(ans, base);
    // Skip the squaring after the highest bit: its result would never be used.
    if (n > 1) MUL(base, base);
  }
  printf("%d\n", ans[vis[x]]);
  return 0;
}