Loading

题解-P5383 普通多项式转下降幂多项式[*easy]

看题解里都是多点求值,但是这不是显然可以分治 \(\tt FFT\) 吗 /yiw

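考虑把下降幂分成两半。设所求的下降幂多项式为 \(G(x) = \sum\limits_{i = 0}^{n - 1} b_i x^{\underline{i}}\)。由于当 \(i > mid\) 时 \(x^{\underline{i}} = x^{\underline{mid+1}} (x - mid - 1)^{\underline{i - mid - 1}}\),有 \(F(x) = G(x) = \sum\limits_{i = 0}^{mid} b_i x^{\underline{i}} + x^{\underline{mid+1}} \sum\limits_{i = mid + 1}^{n - 1} b_i (x - mid - 1)^{\underline{i - mid - 1}}\)。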

发现左边那部分的次数是 \(mid\) 次,右边那部分是 \(mid+1\) 次多项式 \(x^{\underline{mid+1}}\) 的倍数。

左边那部分相当于把 \(F(x)\) 对 \(x^{\underline{mid + 1}}\) 取模,右边那部分相当于把 \(F(x)\) 对 \(x^{\underline{mid + 1}}\) 做除法(取商)。

于是做一遍多项式除法,就变成要算 \(\sum\limits_{i = 0}^{mid} b_i x^{\underline{i}}\) 和 \(\sum\limits_{i = mid + 1}^{n - 1} b_i (x - mid - 1)^{\underline{i - mid - 1}}\)。

这个东西分治下去算就行了,时间复杂度 \(\Theta(n \log ^ 2 n)\)

每次分治只需要一次取模,快于多点求值。

代码:

#include<bits/stdc++.h>
// L(i, j, k): ascending loop header; i runs from j up to k inclusive.
// The bound k is evaluated once and cached in the hidden variable i##E.
#define L(i, j, k) for(int i = j, i##E = k; i <= i##E; i++)
// R(i, j, k): descending loop header; i runs from j down to k inclusive.
// The bound k is likewise evaluated once.
#define R(i, j, k) for(int i = j, i##E = k; i >= i##E; i--)
// Short aliases for common types and helpers. Only `ll` is used by the code
// visible in this file; the others are kept as part of the author's template.
#define ll long long
#define ull unsigned long long
#define db double
#define pii pair<int, int>
#define mkp make_pair
using namespace std;
// Buffered stdin state: buf holds the current chunk and [p1, p2) is its unread part.
char buf[512], *p1 = buf, *p2 = buf;
// Yields the next input byte as a non-negative int, refilling buf from stdin when the
// unread part is empty, or EOF once fread delivers no more data. The byte is converted
// through unsigned char so that bytes >= 0x80 can never be mistaken for EOF.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,512,stdin),p1==p2)?EOF:static_cast<unsigned char>(*p1++))
/**
 * Reads the next non-negative decimal integer from stdin.
 *
 * Skips every byte that is not an ASCII digit, then accumulates consecutive digits.
 * The byte that terminates the digit run is consumed as well.
 *
 * @return the parsed value, or 0 when the input ends before any digit is found
 *         (previously this case looped forever on EOF).
 */
inline int read() {
	// Keep the byte in an int so EOF stays distinguishable and no negative char is
	// ever classified.
	int ch = getchar();
	while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar();
	int x = 0;
	// EOF is negative, so it also ends this loop.
	for (; ch >= '0' && ch <= '9'; ch = getchar()) x = x * 10 + (ch - '0');
	return x;
}
/**
 * Writes x to stdout in decimal using putchar, without any separator.
 * Intended for non-negative values; the callers in this file pass residues.
 */
void print(int x) {
    char digits[12];
    int count = 0;
    // Collect digits least-significant first, then emit them in reverse.
    do {
        digits[count++] = x % 10 + '0';
        x /= 10;
    } while (x > 0);
    while (count > 0) putchar(digits[--count]);
}
// N:   capacity of the global work arrays.
// mod: prime modulus for all arithmetic.
// G:   base from which init derives the roots of unity for the NTT.
// iG:  inverse of G modulo mod (G * iG == mod + 1).
const int N = 1 << 21;
const int mod = 998244353;
const int G = 3;
const int iG = (mod + 1) / G;
/**
 * Modular exponentiation by repeated squaring.
 *
 * @param x base
 * @param y exponent; defaults to mod - 2, so qpow(x) gives the modular inverse
 *          of x when x is not a multiple of the prime mod
 * @return x^y modulo mod
 */
int qpow(int x, int y = mod - 2) {
	long long base = x;
	long long acc = 1;
	while (y) {
		if (y & 1) acc = acc * base % mod;
		base = base * base % mod;
		y >>= 1;
	}
	return static_cast<int>(acc);
}
// A[id] points at the polynomial stored for tree node id; the storage itself is
// carved out of tmp, with Id marking the next free slot (bump allocation, never freed).
int *A[N];
int tmp[N * 40];
int *Id = tmp;
// Lim: transform size the twiddle tables were built for (set by init).
// lim: size of the transform currently being performed (set by up).
int Lim, lim;
// pp: bit-reversal permutation for the current lim (filled by revlim).
int pp[N];
// PowG / iPowG: successive powers of the Lim-th root of unity and of its inverse.
int PowG[N], iPowG[N];
void revlim() { L(i, 0, lim - 1) pp[i] = ((pp[i >> 1] >> 1) | ((i & 1) * (lim >> 1))); }
// Sets lim to the smallest power of two strictly greater than x.
void up(int x) {
	int size = 1;
	while (size <= x) size <<= 1;
	lim = size;
}
// Zeroes the first lim entries of f.
void cle(int *f) {
	std::fill(f, f + lim, 0);
}
void init(int x) {
	int Pw;
	up(x), Lim = lim;
	Pw = qpow(G, (mod - 1) / Lim), PowG[0] = 1;
	L(i, 1, lim - 1) PowG[i] = (ll) PowG[i - 1] * Pw % mod;
	Pw = qpow(iG, (mod - 1) / Lim), iPowG[0] = 1;
	L(i, 1, lim - 1) iPowG[i] = (ll) iPowG[i - 1] * Pw % mod;
}
void NTT(int *f, int flag) {
	L(i, 0, lim - 1) if(pp[i] < i) swap(f[pp[i]], f[i]);
	for(int i = 2; i <= lim; i <<= 1) 
		for(int j = 0, l = (i >> 1), ch = Lim / i; j < lim; j += i) 
			for(int k = j, now = 0; k < j + l; k ++) {
				int pa = f[k], pb = (ll) f[k + l] * (flag == 1 ? PowG[now] : iPowG[now]) % mod;
				f[k] = (pa + pb) % mod, f[k + l] = (pa + mod - pb) % mod, now += ch;
			}
	if(flag == -1) {
		int nylim = qpow(lim);
		L(i, 0, lim - 1) f[i] = (ll) f[i] * nylim % mod;
	}
}
// sav: scratch copy of the input used by inv.
int sav[N];
// sv: not referenced by the code visible in this file.
int sv[N];
// svA, svB: scratch operands for the products formed in div1.
int svA[N], svB[N];
void inv(int *f, int *g, int len) { 
	if(len == 1) return g[0] = qpow(f[0]), void();
	inv(f, g, (len + 1) >> 1), up(len << 1), cle(sav), copy(f, f + len, sav), revlim(), NTT(sav, 1), NTT(g, 1);
	L(i, 0, lim - 1) g[i] = (ll) g[i] * (2ll + mod - (ll) g[i] * sav[i] % mod) % mod;
	NTT(g, -1), fill(g + len, g + lim, 0);
}
void Mul(int *f, int *g, int *ans, int n, int m) {
	static int A[N], B[N];
	up(n + m), revlim(), cle(A), cle(B), copy(f, f + n, A), copy(g, g + m, B);
	NTT(A, 1), NTT(B, 1);
	L(i, 0, lim - 1) A[i] = (ll) A[i] * B[i] % mod;
	NTT(A, -1), copy(A, A + n + m - 1, ans);
}
void div(int *f, int *g, int *ansa, int *ansb, int n, int m) {
	static int A[N];
	reverse(f, f + n), reverse(g, g + m), up((n - m + 1) << 1), cle(A);
	inv(g, A, n - m + 1), Mul(f, A, A, n - m + 1, n - m + 1);
	reverse(A, A + n - m + 1), copy(A, A + n - m + 1, ansa);
	reverse(f, f + n), reverse(g, g + m), Mul(g, A, A, m, n - m + 1);
	L(i, 0, m - 2) ansb[i] = (f[i] - A[i] + mod) % mod;
}
// ans: resulting falling-factorial coefficients, written by div2 at the leaves.
int ans[N];
// sava / savb: scratch for the remainder / quotient produced inside div2.
int sava[N], savb[N];
void div1(int id, int l, int r) {
	A[id] = Id, Id += r - l + 2;
	if(l == r) return A[id][0] = mod - l, A[id][1] = 1, void();
	int mid = (l + r) >> 1;
	div1(id << 1, l, mid), div1(id << 1 | 1, mid + 1, r);
	up(r - l + 1), revlim(), cle(svA), cle(svB);
	copy(A[id << 1], A[id << 1] + mid - l + 2, svA), copy(A[id << 1 | 1], A[id << 1 | 1] + r - mid + 1, svB);
	NTT(svA, 1), NTT(svB, 1);
	L(i, 0, lim - 1) svA[i] = (ll) svA[i] * svB[i] % mod;
	NTT(svA, -1), copy(svA, svA + r - l + 2, A[id]);
}
/**
 * Divide-and-conquer conversion. f holds r - l + 1 ordinary coefficients of the
 * polynomial whose falling-factorial coefficients (in the basis starting at shift l)
 * are ans[l..r].
 *
 * f is divided by the left child's product A[2 * id]: the remainder goes to the
 * left half and the quotient to the right half. At a leaf the single remaining
 * coefficient is the answer.
 *
 * Requires div1 to have populated A for the same tree shape.
 */
void div2(int id, int l, int r, int *f) {
	if (l == r) {
		ans[l] = f[0];
		return;
	}
	const int len = r - l + 1;
	const int mid = (l + r) >> 1;
	const int La = mid - l + 1;  // size of the left half
	const int Lb = r - mid;      // size of the right half

	// Pool storage for the two sub-problems' inputs.
	int *leftIn = Id;
	Id += La + 1;
	int *rightIn = Id;
	Id += Lb + 1;

	up(len + Lb);
	cle(sava);
	cle(savb);
	// Quotient -> savb (Lb coefficients), remainder -> sava (La coefficients).
	div(f, A[id << 1], savb, sava, len, La + 1);
	copy(sava, sava + La, leftIn);
	copy(savb, savb + Lb, rightIn);

	div2(id << 1, l, mid, leftIn);
	div2(id << 1 | 1, mid + 1, r, rightIn);
}
int n, m, f[N], g[N], ansa[N], ansb[N];
int main() {
	n = read(), init(n * 2);
	L(i, 0, n - 1) f[i] = read();
	div1(1, 0, n - 1), div2(1, 0, n - 1, f);
	L(i, 0, n - 1) print(ans[i]), putchar(' ');
	return 0;
} 

其实下降幂多项式转普通多项式也可以分治 \(\tt FFT\),但那道题的题解区里有且比较显然就不讲了。

祝大家学习愉快!

posted @ 2021-01-23 11:40  zhoukangyang  阅读(335)  评论(0)  编辑  收藏  举报