THUPC2017 小 L 的计算题
对每个 $k=1,2,\cdots,n$,求 $\sum\limits_{i=1}^n a_i^k \bmod 998244353$,输出这 $n$ 个答案的异或和(多组数据)。
$n \leq 2 \times 10^5$
sol:
时隔多年终于卡过去了
之前 $O(n\log^2 n) + O(n\log n)$ 的做法卡掉了我的 $O(n\log^2 n) + O(n\log^2 n)$ 做法,有点自闭。
然后 fread + 编译优化 + 预处理单位根 + 不在 fft 里计算 rev 数组大力卡进时限
// THUPC2017 "小 L 的计算题" (sums of powers).
// For each test case: read n and a_1..a_n, compute p_k = sum_i a_i^k (mod 998244353)
// for k = 1..n, and print p_1 xor p_2 xor ... xor p_n.
//
// Method:
//   1. P(x) = prod_i (1 + a_i x) = sum_j e_j x^j via divide-and-conquer NTT (solve).
//   2. With g_j = (-1)^(j-1) * e_j, Newton's identities read
//          p_k = k * g_k + sum_{i=1}^{k-1} g_{k-i} * p_i,
//      an online convolution that is evaluated by CDQ divide and conquer (cdq_fft).
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <iostream>
#include <utility>
#include <vector>

namespace {

using LL = long long;

constexpr int mod = 998244353;
constexpr int kRoot = 3;             // primitive root modulo `mod`
constexpr int kInvRoot = 332748118;  // kRoot^{-1} modulo `mod` (3 * 332748118 == mod + 1)
constexpr int maxn = 800010;         // capacity of every work array
constexpr int kMaxLen = 1 << 19;     // roots are precomputed for transforms up to this length

// ------------------------------------------------------------------ fast input

constexpr int kBufSize = 1 << 16;
char buffer[kBufSize];
char *head = buffer, *tail = buffer;

// Next byte of stdin (refilled in chunks with fread), or EOF once input is exhausted.
// Returned as unsigned char so it is always a valid argument for std::isdigit.
int Getchar() {
    if (head == tail) {
        const std::size_t got = std::fread(buffer, 1, kBufSize, stdin);
        head = buffer;
        tail = buffer + got;
        if (got == 0) return EOF;
    }
    return static_cast<unsigned char>(*head++);
}

// Reads the next integer; a '-' seen before the digits makes it negative.
// Stops (instead of looping forever) when EOF is reached.
LL read() {
    int c = Getchar();
    int sign = 1;
    while (c != EOF && !std::isdigit(c)) {
        if (c == '-') sign = -1;
        c = Getchar();
    }
    LL x = 0;
    while (c != EOF && std::isdigit(c)) {
        x = x * 10 + (c - '0');
        c = Getchar();
    }
    return x * sign;
}

// ------------------------------------------------------------------------- NTT

int rev[maxn];  // bit-reversal permutation for the current transform length
int lg[maxn];   // lg[i] = floor(log2 i), lg[0] = -1
int wn[maxn];   // wn[h]  = primitive (2h)-th root of unity, filled only for powers of two h
int iwn[maxn];  // iwn[h] = inverse of wn[h]

// x^t mod `mod` by binary exponentiation.
int skr(int x, int t) {
    int res = 1;
    for (; t; t >>= 1) {
        if (t & 1) res = static_cast<int>(1LL * res * x % mod);
        x = static_cast<int>(1LL * x * x % mod);
    }
    return res;
}

// Precomputes wn/iwn for every power of two h < n (the only indices fft reads).
void init(int n) {
    wn[0] = iwn[0] = 1;
    for (int h = 1; h < n; h <<= 1) {
        wn[h] = skr(kRoot, (mod - 1) / (h << 1));
        iwn[h] = skr(kInvRoot, (mod - 1) / (h << 1));
    }
}

// Builds rev[] for transform length n (a power of two, n >= 2).
void fft_init(int n) {
    for (int i = 0; i < n; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg[n] - 1));
}

// In-place iterative NTT of a[0..n). type == 1 is forward, type == -1 is inverse (scaled by 1/n).
// Requires fft_init(n) to have been called and all a[i] in [0, mod).
void fft(int *a, int n, int type) {
    for (int i = 0; i < n; ++i)
        if (i < rev[i]) std::swap(a[i], a[rev[i]]);
    for (int half = 1; half < n; half <<= 1) {
        const int step = (type == -1) ? iwn[half] : wn[half];
        for (int j = 0; j < n; j += half << 1) {
            int w = 1;
            for (int k = 0; k < half; ++k, w = static_cast<int>(1LL * w * step % mod)) {
                const int x = a[j + k];
                const int y = static_cast<int>(1LL * w * a[j + k + half] % mod);
                int sum = x + y;
                if (sum >= mod) sum -= mod;
                int diff = x - y;
                if (diff < 0) diff += mod;
                a[j + k] = sum;
                a[j + k + half] = diff;
            }
        }
    }
    if (type == -1) {
        const int inv_n = skr(n, mod - 2);
        for (int i = 0; i < n; ++i) a[i] = static_cast<int>(1LL * a[i] * inv_n % mod);
    }
}

int A[maxn], B[maxn];  // operands of solve()'s products (fully rewritten before each use)
int C[maxn], D[maxn];  // operands of cdq_fft()'s products (re-zeroed after each use)

// Cyclic product of A and B with length len (a power of two); the result is left in A.
// B is overwritten with its transform.
void mulfac(int *A, int *B, int len) {
    fft_init(len);
    fft(A, len, 1);
    fft(B, len, 1);
    for (int i = 0; i < len; ++i) A[i] = static_cast<int>(1LL * A[i] * B[i] % mod);
    fft(A, len, -1);
}

// Same as mulfac, then returns the degree of the product with trailing zeros trimmed.
int mul(int *A, int *B, int len) {
    mulfac(A, B, len);
    int deg = len - 1;
    while (deg > 0 && A[deg] == 0) --deg;
    return deg;
}

// --------------------------------------------------- D&C product of (1 + a_i x)

std::vector<int> poly[maxn];

// Multiplies poly[l..r] together; the product is stored in poly[l] and poly[mid+1] is
// emptied. Returns the degree of the product. Requires l <= r.
int solve(int l, int r) {
    if (l == r) return static_cast<int>(poly[l].size()) - 1;
    const int mid = (l + r) >> 1;
    const int ls = solve(l, mid);
    const int rs = solve(mid + 1, r);

    int len = 1;
    while (len <= ls + rs) len <<= 1;  // need len > deg(product) to avoid wrap-around
    std::copy(poly[l].begin(), poly[l].end(), A);
    std::fill(A + ls + 1, A + len, 0);
    std::copy(poly[mid + 1].begin(), poly[mid + 1].end(), B);
    std::fill(B + rs + 1, B + len, 0);
    poly[mid + 1].clear();

    const int deg = mul(A, B, len);
    poly[l].assign(A, A + deg + 1);
    return deg;
}

// ------------------------------------------ CDQ evaluation of Newton's identities

int f[maxn], g[maxn];

// Fills f[l..r] with p_k = k*g_k + sum_{i<k} f[i]*g[k-i]. Requires f[l..r] to already hold
// the contributions from indices < l, and f[0] == 0.
void cdq_fft(int *f, int *g, int l, int r) {
    if (l == r) {
        f[l] = static_cast<int>((f[l] + 1LL * l * g[l]) % mod);
        return;
    }
    const int mid = (l + r) >> 1;
    cdq_fft(f, g, l, mid);

    // Add f[l..mid] * g[1..r-l] to f[mid+1..r]: C[i-l] = f[i], D[j-1] = g[j],
    // so the term for target t = i + j lands at index t - l - 1.
    const int ls = mid - l + 1;
    const int rs = r - l;
    for (int i = 0; i < ls; ++i) C[i] = f[l + i];
    for (int i = 0; i < rs; ++i) D[i] = g[i + 1];
    int len = 1;
    while (len <= ls + rs - 1) len <<= 1;
    mulfac(C, D, len);
    for (int t = mid + 1; t <= r; ++t) f[t] = (f[t] + C[t - l - 1]) % mod;
    std::fill(C, C + len, 0);  // keep C/D all-zero for the next call
    std::fill(D, D + len, 0);

    cdq_fft(f, g, mid + 1, r);
}

}  // namespace

int main() {
    lg[0] = -1;
    for (int i = 1; i < maxn; ++i) lg[i] = lg[i >> 1] + 1;
    init(kMaxLen);

    int T = static_cast<int>(read());
    while (T-- > 0) {
        const int n = static_cast<int>(read());
        if (n <= 0) {  // XOR over an empty range; solve(1, 0) would never terminate
            std::cout << 0 << '\n';
            continue;
        }
        for (int i = 1; i <= n; ++i) {
            LL v = read() % mod;
            if (v < 0) v += mod;  // reduce any input, including negatives, into [0, mod)
            poly[i].assign({1, static_cast<int>(v)});
        }
        solve(1, n);

        // cdq_fft only reads f[0..n] and g[0..n], so only that prefix needs clearing.
        std::fill(f, f + n + 1, 0);
        std::fill(g, g + n + 1, 0);
        for (std::size_t j = 0; j < poly[1].size(); ++j)
            g[j] = (j & 1) ? poly[1][j] : (mod - poly[1][j]) % mod;  // g_j = (-1)^(j-1) e_j
        poly[1].clear();

        cdq_fft(f, g, 0, n);
        int ans = 0;
        for (int k = 1; k <= n; ++k) ans ^= f[k];
        std::cout << ans << '\n';
    }
    return 0;
}
然而这种 shabi 题为什么我能写 6K,给镘写也就 100 行,我菜的真实