【HDU 5730】Shell Necklace

http://acm.hdu.edu.cn/showproblem.php?pid=5730
分治FFT模板。
DP:\(f(0)=1\),\(f(i)=\sum\limits_{j=0}^{i-1}f(j)\times a(i-j)\),答案为 \(f(n) \bmod 313\)。
计算 \(f(i)\) 需要 \(f(0)\) 到 \(f(i-1)\),因此用 CDQ 分治套 FFT:先递归求出左半区间的 \(f\),再用一次 FFT 卷积把左半区间的 \(f\) 对右半区间每一位的贡献加上,最后递归右半区间。
时间复杂度\(O(n\log^2n)\)

#include<cmath>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;

// Capacity of every global array; must exceed the largest FFT length used
// (the smallest power of two >= n + 1).
const int N = 200003;
// Modulus of the answer: input weights and every DP update are reduced modulo p.
const int p = 313;
// pi, used as the angle base of the FFT roots of unity. const so it cannot be
// overwritten by accident.
const double Pi = acos(-1.0);

// Minimal complex number for the FFT: r is the real part, i the imaginary part.
// The operators are const-qualified so they also work on const operands.
struct cp {
	double r, i;
	cp(double _r = 0, double _i = 0) : r(_r), i(_i) {}
	cp operator + (const cp &x) const {return cp(r + x.r, i + x.i);}
	cp operator - (const cp &x) const {return cp(r - x.r, i - x.i);}
	// (r + i*j)(x.r + x.i*j) = (r*x.r - i*x.i) + (r*x.i + i*x.r)*j
	cp operator * (const cp &x) const {return cp(r * x.r - i * x.i, r * x.i + i * x.r);}
} S[N]; // S: scratch buffer used by DFT for the bit-reversal permutation.

void DFT(cp *A, int *rev, int n, int flag) {
	for (int i = 0; i < n; ++i) S[rev[i]] = A[i];
	for (int i = 0; i < n; ++i) A[i] = S[i];
	for (int len = 2; len <= n; len <<= 1) {
		int mid = len >> 1; cp wn = cp(cos(Pi / mid), sin(Pi / mid) * flag);
		for (int i = 0; i < n; i += len) {
			cp w = cp(1, 0);
			for (int j = 0; j < mid; ++j) {
				cp u = A[i + j], t = A[i + j + mid] * w;
				A[i + j] = u + t;
				A[i + j + mid] = u - t;
				w = w * wn;
			}
		}
	}
	if (flag == -1) for (int i = 0; i < n; ++i) A[i].r /= n;
}

// Complex work buffers that cdq fills and passes to DFT.
cp A[N];
cp B[N];

// n   : length read for the current test case.
// a   : weights a[1..n] (a[0] stays 0).
// f   : DP values.
// rev : bit-reversal permutation table handed to DFT.
int n;
int a[N];
int f[N];
int rev[N];

// CDQ divide-and-conquer ("online") convolution on the index range [l, r].
// It completes f[l..r] for f(i) = sum_{j=0}^{i-1} f(j) * a(i-j)  (mod p).
// On entry, f[l..r] must already hold every contribution from indices < l.
// The left half is solved first. One FFT multiplication then adds its
// contribution to the right half, and the right half is solved recursively.
void cdq(int l, int r) {
	if (l == r) return;
	int mid = (l + r) >> 1;
	cdq(l, mid);

	// FFT length fn: the smallest power of two >= r - l + 1, with c0 = log2(fn).
	// The circular wrap-around of the product only lands on indices
	// < mid + 1 - l, which are never read below.
	int len = r - l + 1, fn = 1, c0 = 0;
	while (fn < len) fn <<= 1, ++c0;

	// Bit-reversal permutation over c0 bits for length fn.
	for (int i = 0; i < fn; ++i) {
		int num = i, &res = rev[i]; res = 0;
		for (int j = 0; j < c0; ++j, num >>= 1) {
			res <<= 1;
			if (num & 1) res |= 1;
		}
	}
	// A = f[l..mid] shifted to start at 0, B = a[0..len-1]; both zero-padded to fn.
	for (int i = l; i <= mid; ++i) A[i - l] = cp(f[i], 0);
	for (int i = mid + 1 - l; i < fn; ++i) A[i] = cp(0, 0);
	for (int i = 0; i < len; ++i) B[i] = cp(a[i], 0);
	for (int i = len; i < fn; ++i) B[i] = cp(0, 0);

	DFT(A, rev, fn, 1);
	DFT(B, rev, fn, 1);
	for (int i = 0; i < fn; ++i) A[i] = A[i] * B[i];
	DFT(A, rev, fn, -1);

	// Coefficient i - l of the product is sum_{j=l}^{mid} f[j] * a[i-j].
	// f and a values are below p, so this can reach (mid-l+1) * 312 * 312,
	// about 4.9e9 for n = 1e5. That exceeds INT_MAX, so round into a 64-bit
	// integer before reducing modulo p.
	for (int i = mid + 1; i <= r; ++i) {
		long long add = static_cast<long long>(A[i - l].r + 0.5) % p;
		f[i] = static_cast<int>((f[i] + add) % p);
	}
	cdq(mid + 1, r);
}

// Processes test cases until a terminating 0. For each case it reads n and
// a[1..n], runs the CDQ+FFT recurrence and prints f[n] (already reduced mod p).
int main() {
	while (true) {
		// Also stop on EOF or a failed read, instead of looping forever on a
		// stale n. Stop on negative n too: cdq(0, n) with n < 0 never reaches
		// its l == r base case.
		if (scanf("%d", &n) != 1 || n <= 0) break;
		memset(a, 0, sizeof(a));
		memset(f, 0, sizeof(f));
		for (int i = 1; i <= n; ++i) scanf("%d", a + i), a[i] %= p;
		f[0] = 1;  // Base case of the recurrence.
		cdq(0, n);
		printf("%d\n", f[n]);
	}
	return 0;
}
posted @ 2017-04-05 20:15  abclzr  阅读(210)  评论(0)  编辑  收藏  举报