fft练习
数学相关一直都好弱啊>_<
窝这个月要补一补数学啦, 先从基础的fft补起吧!
现在做了 道。
窝的fft 模板 (bzoj 2179)
1 #include <iostream> 2 #include <cstdio> 3 #include <cstring> 4 #include <cmath> 5 #include <algorithm> 6 #define MAXN 200005 7 #define PI M_PI 8 using namespace std; 9 struct CP{ 10 double x, y; 11 CP(){} 12 CP(double x, double y) : x(x), y(y) {} 13 inline CP operator+ (CP b) {return CP(x+b.x, y+b.y);} 14 inline CP operator- (CP b) {return CP(x-b.x, y-b.y);} 15 inline CP operator* (CP b) {return CP(x*b.x-y*b.y, x*b.y+y*b.x);} 16 }a[MAXN], b[MAXN], A[MAXN], x, y; 17 int n, N, l, dig[MAXN], ans[MAXN]; 18 char s1[MAXN], s2[MAXN]; 19 inline void dft(CP *a, int n, int f){ 20 for(int i = 0; i < n; i ++) A[i] = a[dig[i]]; 21 for(int i = 0; i < n; i ++) a[i] = A[i]; 22 for(int i = 2; i <= n; i <<= 1){ 23 CP wn(cos(2*PI/i), f*sin(2*PI/i)); 24 for(int k = 0; k < n; k += i){ 25 CP w(1, 0); 26 for(int j = 0; j < i/2; j ++) x = a[k+j], y = w*a[k+j+i/2], a[k+j]=x+y, a[k+j+i/2] = x-y, w = w*wn; 27 } 28 } 29 } 30 int main(){ 31 scanf("%d%s%s", &n, s1, s2); 32 for(int i = 0; i < n; i ++) a[i].x = s1[n-i-1]-'0', b[i].x = s2[n-i-1]-'0'; 33 for(N = 1, l = 0; N < n; N <<= 1, l ++); N <<= 1, l ++; 34 for(int i = 0; i < N; i ++){ 35 int ret = 0, p = i; 36 for(int j = 1; j <= l; j ++) ret <<= 1, ret += (p&1), p >>= 1; 37 dig[i] = ret; 38 } 39 dft(a, N, 1), dft(b, N, 1); 40 for(int i = 0; i < N; i ++) a[i] = a[i]*b[i]; 41 dft(a, N, -1); 42 for(int i = 0; i < N; i ++) ans[i] = (int)(a[i].x / N + 0.5); 43 for(int i = 0; i < N; i ++) ans[i+1] += ans[i]/10, ans[i] %= 10; 44 l = N; while(l > 1 && ans[l-1] == 0) l --; 45 for(int i = l-1; i >= 0; i --) printf("%d", ans[i]); cout << endl; 46 return 0; 47 }
BZOJ 2194
之前做过的啦,求 c_k = Σ(a_i * b_(i-k)) (k <= i)
标准的差积形式是 i + j 一定, 现在的式子要求 i - j 一定, 容易想到把所有b的下标都乘以-1原式即转化为i+j一定了, 为了方便处理可以把所有b再加上n使得下标为非负数,即把b_i 变为 b_(n-i) 。
tips: 卷积的形式并不一定只是 k = i+j, 还有可能是 k = i-j 或者 …………
bzoj 3527: [Zjoi2014]力
之前做过的啦。 不打式子了, 是个人都会把qi乘进去, 然后原式就变成了求 E_i = Σqj/((i-j)^2) (j<i) - Σqj/((i-j)^2) (j>i)
我第一次做的时候并没有看出来这是个卷积QAQ 蠢死了QAQ
我们可以设 一个 函数 f(x) = q_x, g(x) = 1/(x^2), 然后那个式子的前一项不就是 f 和 g 的卷积了吗! 第二个式子是 当 i-j一定时候的情况, 那不就是上一题 bzoj 2194了吗,,,,然后就没了。
tips:不要拘泥于已有的数列,可以根据题目自己构造出新的函数, 需要求 F_i = Σ( f(j) 乘或除以 g(i±j) ) 的时候很有可能是卷积!!!看到 j 和 i±j 这类的关键词就应该往这边多想一想。
BZOJ 3771 Triple
之前做过的啦。给定n个物品,可以用一个/两个/三个不同的物品凑出不同的价值,求每种价值有多少种拼凑方案(顺序不同算一种) (题目大意来自popoqqq)
自己还是没有想到QAQ
这个要构建一个母函数。初步地感受母函数的应用:如果我们可以快速地计算出一次操作后每种效果有多少种产生方案(0/1),并且多次操作的效果是可以累加的,那么在用fft优化以后(效果的累加不就是卷积吗)可以比较快速地(k*nlogn)计算出k次操作后每种效果有多少种产生方案 (怎么觉得我在说的是用快速幂优化的矩阵乘法啊。。。。。)感觉的确和矩乘有点像, 不过矩乘的每次操作的状态是一个二维的数组, 因为在不同的点可以进行的操作不同, 但是 母函数只需要求出一个一维数组, 因为在应用母函数的时候当前状态 并不会影响下一步的操作
(以上皆是窝自己的口胡,等我读完具体数学的 "生成函数" 一章后再来总结吧QAQ)
tips: 见上
BZOJ 3160 万径人踪灭
给出一个只含有a和b的字符串, 求有多少个至少有一个断点(即不完全连续)的回文子序列,要求子序列选出来的每一个的位置同样必须关于任意一个位置或中缝对称。
考虑以每一个点i为中间的点所组成的符合要求的子序列,对于每一个j, 如果 s[i-j] == s[i+j] 则可以把这对字符加到子序列里, 当然也可以不加, 显然对于不同的j, 每一对字符是相互独立的, 所以以i为中点的贡献就是 2^(s[i-j]==s[i+j]的数量), 注意因为子序列不可以完全连续所以还要再减去一个可以直接manachar求出来的完全连续的回文串数量。 所以现在唯一剩下的问题就是对于每一个i, Σ (s[i-j] == s[i+j]), 因为只有a, b两种字符, 显然可以拆开来算, 设数组A使得 A_i = (s[i] == 'a'), 则字符a对于以i为中心贡献就是 Σ (s[j]*s[2*i-j]) , 这个显然就是一个卷积啦。
1 #include <iostream> 2 #include <cstdio> 3 #include <cmath> 4 #include <algorithm> 5 #include <cstring> 6 #define ll long long 7 #define mod 1000000007 8 #define MAXN 1000005 9 #define PI M_PI 10 using namespace std; 11 char s[MAXN]; 12 int N, L, len, r[MAXN], rev[MAXN]; 13 ll ans = mod; 14 ll mypow(ll x, int k){ 15 ll ret = 1; 16 while(k){ 17 if(k&1)(ret *= x) %= mod; 18 (x*=x)%=mod; k >>= 1; 19 }return ret; 20 } 21 struct CP{ 22 double x, y; 23 CP () {} 24 CP (double x, double y) : x(x), y(y) {} 25 inline CP operator+ (CP b){return CP(x+b.x, y+b.y);} 26 inline CP operator- (CP b){return CP(x-b.x, y-b.y);} 27 inline CP operator* (CP b){return CP(x*b.x-y*b.y, x*b.y+y*b.x);} 28 }A[MAXN], B[MAXN], T[MAXN], x, y; 29 void dft(CP *a, int n, int f){ 30 for(int i = 0; i < n; i ++) T[i] = a[rev[i]]; 31 for(int i = 0; i < n; i ++) a[i] = T[i]; 32 for(int i = 2; i <= n; i <<= 1){ 33 CP wn(cos(PI*2/i), f*sin(PI*2/i)); 34 for(int j = 0; j < n; j += i){ 35 CP w(1, 0); 36 for(int k = 0; k < i/2; k ++) 37 x = a[j+k], y = w*a[j+k+i/2], a[j+k] = x+y, a[j+k+i/2] = x-y, w = w*wn; 38 } 39 } 40 } 41 int main(){ 42 scanf("%s", s + 1); 43 len = strlen(s+1); 44 for(int i = len; i >= 1; i --) s[i+i] = s[i], s[i+i-1] = '*'; 45 s[0] = '&'; s[len+len+1] = '*'; s[len+len+2] = '#'; 46 int k = 0; r[0] = r[1] = 1; 47 for(int i = 2; i <= len+len; i ++){ 48 r[i] = max(min(k+r[k]-i, r[k-(i-k)]), 1); 49 while(i-r[i]>0 && s[i+r[i]] == s[i-r[i]]) r[i] ++; 50 if(i+r[i] > k+r[k]) k = i; 51 // printf("!!! %d %d\n", i, r[i]/2); 52 ans -= r[i]/2; 53 if(ans < 0) ans += mod; 54 } 55 // cout << ans << endl; 56 int n = len+len; 57 for(int i = 0; i <= n; i ++) A[i].x = (s[i]=='a'), B[i].x = (s[i]=='b'); 58 for(N=1, L=0; N <= n; N<<=1, L ++); N <<= 1, L ++; 59 for(int i = 0; i < N; i ++){ 60 int ret = 0, p = i; 61 for(int j = 1; j <= L; j ++) ret <<= 1, ret += (p&1), p >>= 1; 62 rev[i] = ret; 63 } 64 dft(A, N, 1); dft(B, N, 1); 65 for(int i = 0; i < N; i ++) A[i] = A[i]*A[i] + B[i]*B[i]; 66 dft(A, N, -1); 67 for(int i = 1; i < N; i ++){ 68 (ans += mypow(2, ((ll)(A[i].x/N+0.5)+1)>>1)-1) %= mod; 69 // printf("%d %I64d\n", i, (ll)(A[i].x/N+0.5)); 70 } 71 cout << ans << endl; 72 //system("pause"); 73 return 0; 74 }
hdu 4509 3-idiot
给你三根线段, 问你有多少种可以组成三角形的方法。
做过bzoj3771后这题就是一眼题了QAQ 类似地,对于边长建出一个生成函数然后卷积一次算出由两条边组成的长度和为x的线段组有多少组然后再扫一遍第三条边就可以了。
BZOJ 3992 Sdoi2015 序列统计
给你一个包含了S个数的模M意义下的集合,问只用集合中的数组成一个长度为N的序列使得序列中所有数的乘积为X的方案数有多少种。
模意义下的乘积显然可以求一个原根搞成指标的形式,于是就变成一堆数的和的形式了, 直接上裸的FFT就好, 因为这道题要求取模且模数又是一个X*2^p+1形式的数所以把FFT写成数论版的就可以了。
BZOJ 3456 城市规划
编号是很优美的3456哦~
求出n个点的简单(无重边无自环)无向连通图数目.
递推式就是串珠子那道题嘛,枚举最小点所在连通块大小然后算出不连通的方案数再容斥一下就好了 fi = 2C(2,i) - ∑(j = 1~(i-1)) fj* C(j-1, i-1) * 2C(2,i-j)
把C(j-1,i-1) 拆了然后把 (i-1)!这一项提出来, 显然就是一个卷积了。
因为这里每一步的卷积会用到之前的结果, 1~i-1中的每一个结果都会对fi产生贡献, 用cdq分治的思想搞一下就可以了。work(l,r)计算出l~r中每一个数的f值。先调用work(l,mid)得到l~mid中每一个f值,然后算出 l~mid中所有数对 mid+1~r的贡献, 最后再调用 work(mid+1,r)就可以了。其中 l~mid 对 mid+1到r的贡献做一次多项式乘法就可以一起求出来啦!
分治 + fft , O(nlog2n)。 很好理解代码也很好写。
1 #include <iostream> 2 #include <cstdio> 3 #include <cstring> 4 #include <cmath> 5 #include <algorithm> 6 #define mod 1004535809 7 #define MAXN 300005 8 #define G 3 9 #define ll long long 10 using namespace std; 11 int n, inv2[MAXN], inv[MAXN], fac[MAXN], infac[MAXN], T[MAXN], f[MAXN], A[MAXN], B[MAXN], tmp[MAXN], rev[MAXN]; 12 int mypow(int x, int k){ 13 int ans = 1; 14 while(k){ 15 if(k&1) ans = (ll)ans*x%mod; 16 x = (ll)x*x%mod; k >>= 1; 17 } return ans; 18 } 19 void dft(int *a, int n, int f){ 20 for(int i = 0; i < n; i ++) tmp[i] = a[rev[i]]; 21 for(int i = 0; i < n; i ++) a[i] = tmp[i]; 22 for(int i = 2; i <= n; i <<= 1){ 23 ll w = mypow(G, ((ll)f*(mod-1)/i)%(mod-1)); 24 for(int j = 0; j < n; j += i){ 25 ll wn = 1; 26 for(int k = 0; k < i/2; k ++){ 27 ll x = a[j+k], y = (wn*a[j+k+i/2])%mod; 28 a[j+k] = (x+y)%mod, a[j+k+i/2] = (x-y+mod)%mod; (wn*=w)%=mod; 29 } 30 } 31 } 32 } 33 void solve(int l, int r){ 34 if(l == r){ 35 f[l] = (T[l] - (ll) fac[l-1] * f[l] % mod + mod) % mod; return; 36 } 37 int mid = l + r >> 1, nn = max(mid-l+1, r-mid); 38 solve(l, mid); 39 int N, L; 40 for(N=1,L=0; N<=nn; N<<=1, L++); N<<=1, L++; 41 for(int i = 0; i < N; i ++){ 42 int ret = 0, p = i; 43 for(int j = 1; j <= L; j ++) ret <<= 1, ret += (p&1), p >>= 1; 44 rev[i] = ret; 45 } 46 for(int i = 0; i < N; i ++) A[i] = B[i] = 0; 47 for(int i = l; i <= mid; i ++) A[i-l] = (ll) f[i] * infac[i-1] % mod; 48 for(int i = 1; i <= r-l; i ++) B[i] = (ll) T[i] * infac[i] % mod; 49 dft(A, N, 1); dft(B, N, 1); 50 for(int i = 0; i < N; i ++) A[i] = (ll)A[i]*B[i] % mod; 51 dft(A, N, mod-2); 52 for(int i = mid+1; i <= r; i ++) (f[i] += (ll)A[i-l] * inv[N] % mod) %= mod; 53 solve(mid+1, r); 54 } 55 int main(){ 56 scanf("%d", &n); 57 int N, l; 58 for(N=1,l=0; N<=n; N<<=1, l++); N<<=1, l ++; 59 inv2[0] = 1, inv2[1] = mypow(2, mod-2); 60 for(int i = 2; i <= N; i ++) inv2[i] = (ll)inv2[i-1]*inv2[1] % mod; 61 for(int i = 1; i <= N; i ++) inv[i] = mypow(i, mod-2); 62 fac[0] = infac[0] = 1; 63 for(int i = 1; i <= N; i ++) fac[i] = (ll)fac[i-1]*i%mod, infac[i] = (ll)infac[i-1]*inv[i]%mod; 64 for(int i = 1; i <= N; i ++) T[i] = mypow(2, (((ll)i*(i-1))/2)%(mod-1)); 65 solve(1, n); 66 printf("%d\n", f[n]); 67 return 0; 68 }
当然这道题有更简单的 O(nlogn) 的多项式求逆元的做法。
codeforces #250E The Child and Binary Tree
Picks博客上面讲的多项式开根练习题。
bzoj 上面卡不过去QAQ
1 #include <iostream> 2 #include <cstdio> 3 #include <cmath> 4 #include <algorithm> 5 #include <cstring> 6 #define ll long long 7 #define MAXN 500005 8 #define mod 998244353 9 #define G 3 10 #define inv_2 499122177 11 using namespace std; 12 int N, L, n, m, c[MAXN], d[MAXN], rev[MAXN], T[MAXN]; 13 int mypow(int x, int k){ 14 int ret = 1; 15 while(k){ 16 if(k&1) ret = ((ll)ret*x)%mod; 17 x = (ll)x*x%mod; k >>= 1; 18 }return ret; 19 } 20 void dft(int *a, int n, int f){ 21 for(int i = 0; i < n; i ++) T[i] = a[rev[i]]; 22 for(int i = 0; i < n; i ++) a[i] = T[i]; 23 for(int i = 2; i <= n; i <<= 1){ 24 ll w = mypow(G, ((ll)f*(mod-1)/i) % (mod-1)); 25 for(int j = 0; j < n; j += i){ 26 ll wn = 1; 27 for(int k = 0; k < (i>>1); k ++){ 28 ll x = a[j+k], y = wn * a[j+k+(i>>1)] % mod; 29 a[j+k] = (x+y)%mod, a[j+k+(i>>1)] = (x-y+mod)%mod, (wn *= w) %= mod; 30 } 31 } 32 } 33 } 34 void getinv(int *a, int *b, int n){ 35 static int tmp[MAXN]; 36 if(n == 1) {b[0] = mypow(a[0], mod-2); return;} 37 getinv(a, b, n>>1); 38 memcpy(tmp, a, n*4); 39 memset(tmp+n, 0, n*4); 40 41 int l = 0, tt = n<<1, NN = n<<1; 42 while(tt) tt >>= 1, l ++; l --; 43 for(int i = 0; i < NN; i ++){ 44 int la = 0, p = i; 45 for(int j = 0; j < l; j ++) la <<= 1, la += (p&1), p >>= 1; 46 rev[i] = la; 47 } 48 49 dft(tmp, NN, 1); 50 dft(b, NN, 1); 51 for(int i = 0; i < NN; i ++) 52 tmp[i] = (ll)b[i] * (2-(ll)tmp[i]*b[i]%mod + mod)%mod; 53 dft(tmp, NN, mod-2); 54 ll inv = mypow(NN, mod-2); 55 for(int i = 0; i < n; i ++) 56 b[i] = tmp[i] * inv % mod; 57 memset(b+n, 0, n*4); 58 } 59 void getsqrt(int *a, int *b, int n){ 60 static int tmp[MAXN], b_in[MAXN]; 61 if(n == 1) {b[0] = 1; return;} 62 getsqrt(a, b, n>>1); 63 memset(b_in, 0, n*4); 64 getinv(b, b_in, n); 65 memcpy(tmp, a, n*4); 66 memset(tmp+n, 0, n*4); 67 68 int l = 0, tt = n<<1, NN = n<<1; 69 while(tt) tt >>= 1, l ++; l --; 70 for(int i = 0; i < NN; i ++){ 71 int la = 0, p = i; 72 for(int j = 0; j < l; j ++) la <<= 1, la += (p&1), p >>= 1; 73 rev[i] = la; 74 } 75 76 dft(tmp, NN, 1); 77 dft(b, NN, 1); 78 dft(b_in, NN, 1); 79 for(int i = 0; i < NN; i ++) 80 tmp[i] = ((ll) inv_2 * (b[i] + (ll)tmp[i]*b_in[i] % mod)) % mod; 81 dft(tmp, NN, mod-2); 82 ll inv = mypow(NN, mod-2); 83 for(int i = 0; i < n; i ++) 84 b[i] = tmp[i]*inv%mod; 85 memset(b+n, 0, n*4); 86 } 87 int main(){ 88 scanf("%d%d", &n, &m); 89 for(N=1,L=0; N<=m; N<<=1, L++); 90 for(int i = 1; i <= n; i ++){ 91 int x; scanf("%d", &x); 92 c[x] -= 4; 93 if(c[x] < 0) c[x] += mod; 94 } 95 c[0] = 1; 96 getsqrt(c, d, N); 97 static int C[MAXN], D[MAXN]; 98 memcpy(C, d, N*4); 99 (++ C[0]) %= mod; 100 getinv(C, D, N); 101 for(int i = 1; i <= m; i ++) printf("%d\n", (D[i]<<1) % mod); 102 //system("pause"); 103 return 0; 104 }