题解-Luogu P4512 多项式除法
114514 年前写的题解,搬过来了
0.前言
必备内容:NTT,多项式求逆
建议到博客中查看
1.正文
定义 \(F_R(x)\) 为把 \(F(x)\) 系数翻转后的函数
若 \(F(x)\) 次数为 \(k\) ,那么 \(F_R(x)=F(\frac{1}{x})*x^k\)
若 \(f(x)=g(x)*q(x)+r(x)\), \(f\) 次数为 \(n\) ,\(g\) 次数为 \(m\), \(q\) 次数为 \(n-m\) , \(r\) 次数不超过 \(m-1\)(不足 \(m-1\) 次时高位补 \(0\),统一看作 \(m-1\) 次)
我们可以把 \(\frac{1}{x}\) 代入函数
得
\[f(\frac{1}{x})=g(\frac{1}{x})*q(\frac{1}{x})+r(\frac{1}{x})
\]
\[f(\frac{1}{x})*x^n=g(\frac{1}{x})*x^m*q(\frac{1}{x})*x^{n-m}+r(\frac{1}{x})*x^{m-1}*x^{n-m+1}
\]
\[f_R(x)=g_R(x)*q_R(x)+r_R(x)*x^{n-m+1}
\]
\[f_R(x) \equiv g_R(x)*q_R(x) \pmod {x^{n-m+1}}
\]
\[q_R(x)\equiv \frac{f_R(x)}{g_R(x)} \pmod {x^{n-m+1}}
\]
从而我们可以知道 \(q_R(x) \bmod x^{n-m+1}\)
我们知道 \(q\) 次数为 \(n-m\) ,而模数次数为 \(n-m+1\)
于是取模不会丢掉 \(q_R\) 的任何一项,把 \(q_R\) 的系数翻转回来就可以算出 \(q\) 啦!!!!!
然后我们可以用
\[r(x)=f(x)-g(x)*q(x)
\]
然后 \(r\) 也算出来啦!
2.代码
#include<bits/stdc++.h>
#define L(i, j, k) for(int i = j, i##E = k; i <= i##E; i++)
#define R(i, j, k) for(int i = j, i##E = k; i >= i##E; i--)
#define ll long long
#define ull unsigned long long
#define db double
#define pii pair<int, int>
#define mkp make_pair
using namespace std;
// Shared numeric constants for the NTT-based routines in this file:
//   N   - capacity of every global coefficient buffer (524288 = 2^19);
//   mod - the prime modulus every intermediate result is reduced by;
//   G   - base whose powers are used as roots of unity in the forward transform;
//   iG  - inverse of G modulo mod: (mod + 1) is divisible by 3, so
//         G * ((mod + 1) / G) = mod + 1, which is congruent to 1 modulo mod.
const int N = 524288, mod = 998244353, G = 3, iG = (mod + 1) / G;
// Modular exponentiation by repeated squaring: returns x^y reduced modulo mod.
// With the default exponent (mod - 2) the result is the modular inverse of x
// for any x not divisible by mod (Fermat's little theorem, mod being prime).
int qpow(int x, int y = mod - 2) {
    int acc = 1;
    while (y) {
        // Multiply the current square into the result when the low bit is set.
        if (y & 1) {
            acc = static_cast<long long>(acc) * x % mod;
        }
        x = static_cast<long long>(x) * x % mod;
        y >>= 1;
    }
    return acc;
}
int lim, pp[N];
void revlim() { L(i, 0, lim - 1) pp[i] = ((pp[i >> 1] >> 1) | ((i & 1) * (lim >> 1))); }
// Sets lim to the smallest power of two strictly greater than x (at least 1).
void up(int x) {
    lim = 1;
    while (lim <= x) {
        lim <<= 1;
    }
}
// Zeroes the first lim entries of f.
void cle(int *f) {
    std::fill(f, f + lim, 0);
}
void NTT(int *f, int flag) {
L(i, 0, lim - 1) if(pp[i] < i) swap(f[pp[i]], f[i]);
for(int i = 2; i <= lim; i <<= 1)
for(int j = 0, l = (i >> 1), ch = qpow(flag == 1 ? G : iG, (mod - 1) / i); j < lim; j += i)
for(int k = j, now = 1; k < j + l; k ++) {
int pa = f[k], pb = (ll) f[k + l] * now % mod;
f[k] = (pa + pb) % mod, f[k + l] = (pa + mod - pb) % mod;
now = (ll) now * ch % mod;
}
if(flag == -1) {
int nylim = qpow(lim);
L(i, 0, lim - 1) f[i] = (ll) f[i] * nylim % mod;
}
}
int sav[N], sv[N];
void inv(int *f, int *g, int len) {
if(len == 1) return g[0] = qpow(f[0]), void();
inv(f, g, (len + 1) >> 1), up(len << 1), cle(sav), copy(f, f + len, sav), revlim(), NTT(sav, 1), NTT(g, 1);
L(i, 0, lim - 1) g[i] = (ll) g[i] * (2ll + mod - (ll) g[i] * sav[i] % mod) % mod;
NTT(g, -1), fill(g + len, g + lim, 0);
}
int sava[N], savb[N];
void div(int *f, int *g, int *ansa, int *ansb, int n, int m) {
up(n << 1), cle(sava), cle(savb), copy(g, g + m, sava), reverse(sava, sava + m);
inv(sava, savb, n), up(n << 1), revlim(), copy(f, f + n, sava), reverse(sava, sava + n);
NTT(sava, 1), NTT(savb, 1);
L(i, 0, lim - 1) ansa[i] = (ll) sava[i] * savb[i] % mod;
NTT(ansa, -1), fill(ansa + n - m + 1, ansa + lim, 0), reverse(ansa, ansa + n - m + 1);
cle(sav), cle(sava), copy(ansa, ansa + n - m + 1, sav), copy(g, g + m, sava), NTT(sav, 1), NTT(sava, 1);
L(i, 0, lim - 1) sav[i] = (ll) sav[i] * sava[i] % mod;
NTT(sav, -1);
L(i, 0, m - 2) ansb[i] = (f[i] + mod - sav[i]) % mod;
}
// n, m: coefficient counts of dividend and divisor (input degrees plus one).
// f, g: input polynomials; ansa, ansb: quotient and remainder written by div().
int n, m, f[N], g[N], ansa[N], ansb[N];

// Reads the two degrees and both coefficient lists, divides f by g, and prints
// the quotient coefficients on one line and the remainder coefficients on the
// next, each value followed by a single space.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    cin >> n >> m;
    // Convert degrees into coefficient counts.
    ++n;
    ++m;

    for (int i = 0; i < n; ++i) {
        cin >> f[i];
    }
    for (int i = 0; i < m; ++i) {
        cin >> g[i];
    }

    div(f, g, ansa, ansb, n, m);

    const int quotientLast = n - m;
    for (int i = 0; i <= quotientLast; ++i) {
        cout << ansa[i] << " ";
    }
    cout << endl;

    const int remainderLast = m - 2;
    for (int i = 0; i <= remainderLast; ++i) {
        cout << ansb[i] << " ";
    }
    cout << endl;
    return 0;
}
祝大家学习愉快!