【luogu P4238】【模板】多项式乘法逆(NTT)(倍增)
【模板】多项式乘法逆
题目链接:luogu P4238
题目大意
给你一个多项式 F(x),要你求一个多项式 G(x),使得 F(x)*G(x)≡1(mod x^n),系数对 998244353 取模。
思路
考虑用倍增的方法。
首先只有 \(1\) 项的时候,就直接是逆元。
然后我们考虑已知 \(0\sim \frac{n}{2}\) 项的答案 \(G'(x)\),然后求 \(0\sim n\) 的答案 \(G(x)\)。(不难看出后面增加不会影响前面的)
那就有 \(G'(x)=G(x)\pmod{x^{\frac{n}{2}}}\)
即 \(G'(x)-G(x)=0\pmod{x^{\frac{n}{2}}}\)
我们考虑把 \(\pmod{x^{\frac{n}{2}}}\) 变成 \(\pmod{x^{n}}\),用平方。
\((G'(x)-G(x))^2=0\pmod{x^{n}}\)
\(G'(x)^2-2G(x)G'(x)+G(x)^2=0\pmod{x^{n}}\)
然后每一项和右边的 \(0\) 都乘上 \(F(x)\):
\(F(x)G'(x)^2-2G'(x)+G(x)=0\pmod{x^{n}}\)
\(G(x)=2G'(x)-F(x)G'(x)^2\pmod{x^{n}}\)
然后就可以用多项式 NTT 来算右边,然后直接 \(O(n)\) 加就好了。
然后好像有一个小小可以优化的方法:
你先算 \(F(x)G'(x)\),然后再乘上 \(G'(x)\)。
然后由于我们只要 \(\frac{n}{2}\sim n-1\) 项的。
你第一次乘的时候前面 \(0\sim \frac{n}{2}-1\) 项都是 \(0\)(因为它们在 \(\bmod{x^{\frac{n}{2}}}\) 的时候是 \(0\))
所以前面可以直接省去,然后第二次也可以直接省去前面的。
所以就不用每次都扩大多项式的范围,就可以节省时间。
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
#define mo 998244353
#define clear(f, n) memset(f, 0, n * sizeof(ll))
#define cpy(f, g, n) memcpy(f, g, n * sizeof(ll))
using namespace std;
int n, an[300001], limit, l_size;
ll f[300001], G, Gv;
ll w[300001], r[300001], tmp[300001];
ll ksm(ll x, ll y) {
ll re = 1;
while (y) {
if (y & 1) re = re * x % mo;
x = x * x % mo;
y >>= 1;
}
return re;
}
void get_an() {
for (int i = 0; i < limit; i++)
an[i] = (an[i >> 1] >> 1) | ((i & 1) << (l_size - 1));
}
void NTT(ll *f, ll op) {//NTT
get_an();
for (int i = 0; i < limit; i++)
if (i < an[i]) swap(f[i], f[an[i]]);
for (int mid = 1; mid < limit; mid <<= 1) {
ll Wn = ksm(op == 1 ? G : Gv, (mo - 1) / (mid << 1));
for (int R = (mid << 1), j = 0; j < limit; j += R) {
ll w = 1;
for (int k = 0; k < mid; k++, w = w * Wn % mo) {
ll x = f[j + k], y = w * f[j + mid + k] % mo;
f[j + k] = (x + y) % mo;
f[j + mid + k] = (x - y + mo) % mo;
}
}
}
if (op == -1) {//在这里直接乘了
ll liv = ksm(limit, mo - 2);
for (int i = 0; i < limit; i++)
f[i] = f[i] * liv % mo;
}
}
void px(ll *x, ll *y) {
for (int i = 0; i < limit; i++)
x[i] = x[i] * y[i] % mo;
}
void cheng(ll *x, int n, ll *y, int m) {//只是写着,并没有用
limit = 1; l_size = 0;
while (limit < n + m + 1) {
limit <<= 1; l_size++;
}
NTT(x, 1); NTT(y, 1);
px(x, y); NTT(x, -1);
}
void invp(ll *F, int n) {
w[0] = ksm(F[0], mo - 2);
l_size = 0;
for (int len = 2; (len >> 1) <= n; len <<= 1) {//倍增
limit = len; l_size++;
for (int i = 0; i < (len >> 1); i++) r[i] = w[i];
cpy(tmp, F, len);//按着操作来把三个乘上
NTT(tmp, 1); NTT(r, 1);
px(r, tmp); NTT(r, -1);
clear(r, (len >> 1));
cpy(tmp, w, len);
NTT(tmp, 1); NTT(r, 1);
px(r, tmp); NTT(r, -1);
for (int i = len >> 1; i < len; i++)//按着公式弄
w[i] = (w[i] * 2 - r[i] + mo) % mo;
}
cpy(F, w, n);
clear(tmp, n); clear(w, n); clear(r, n);//清空为好
}
int main() {
G = 3; Gv = ksm(G, mo - 2);
scanf("%d", &n);
for (int i = 0; i < n; i++) scanf("%lld", &f[i]);
invp(f, n);
for (int i = 0; i < n; i++) printf("%lld ", f[i]);
return 0;
}