一些模板(持续更新中)
诶,开这篇的原因之一是要学习多项式全家桶,想找个地方存板子,原因之二是发现自己以前的一些模板已经非常不熟悉了,甚至于一些细节已经不知道为什么这么写了。
可能这里应该放一些代码,更多理论知识请查看学习笔记。
多项式全家桶
多项式乘法
NTT
// Core butterfly loop of the iterative NTT (excerpt; see NTT() in the full code).
// Assuming cnt starts at 1 (as in the full code), len == 2^cnt on every round.
for (int len = 2; len <= lim; len <<= 1) {
// g[cnt] is the primitive 2^cnt-th root of unity mod Mod, invg[cnt] its inverse
// (used for the inverse transform, type <= 0).
int half = len >> 1, w = (type > 0) ? g[cnt] : invg[cnt];
// Every length-len block is combined independently.
for (int bg = 0; bg < lim; bg += len) {
int wn = 1; // current power w^k, restarted for each block
for (int pos = bg; pos < bg + half; ++pos) {
// Butterfly: (a, b) -> (a + w^k * b, a - w^k * b).
int tmp = Mul(wn, poly[pos + half]);
poly[pos + half] = Dec(poly[pos], tmp);
poly[pos] = Add(poly[pos], tmp);
wn = Mul(wn, w);
}
}
++cnt; // len doubled, so move to the next root-of-unity level
}
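我将对上面这段代码中我产生过的疑问进行记录。
- 为什么是 `g[cnt]`?我怎么知道该用哪一个?
  - 因为 `g[x]` 是模意义下的 \(2^x\) 次单位根(\(g[x] = G^{(Mod-1)/2^x}\)),它能完成的是项数为 \(2^x\) 的多项式的点值转化。`cnt` 从 1 开始,且每轮 `len` 翻倍时加一,所以始终有 \(len = 2^{cnt}\),这一轮要用的正是 `g[cnt]`。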
完整代码:
#include <cctype>
#include <cstdio>
#include <iostream>
using namespace std;
typedef long long ll;
// NTT-friendly prime 998244353 = 119 * 2^23 + 1 with primitive root G = 3.
const int Mod = 998244353, G = 3, N = 1e6 + 5;

// Reads one decimal integer (optionally preceded by '-') from stdin, skipping
// any other leading characters. Returns 0 if EOF is reached before a digit
// instead of looping forever.
inline int Rd() {
    int ret = 0, fu = 1;
    // int, not char: must be able to hold EOF, and isdigit() is undefined
    // for negative char values.
    int ch = getchar();
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-')
            fu = -1;
        ch = getchar();
    }
    while (isdigit(ch))  // isdigit(EOF) is false, so this also stops at EOF
        ret = ret * 10 + (ch - '0'), ch = getchar();
    return ret * fu;
}

// Modular addition for x, y in [0, Mod); the result stays in [0, Mod).
// Must be >=: with >, x + y == Mod would return Mod instead of 0.
inline int Add(int x, int y) { return (x + y >= Mod) ? (x + y - Mod) : (x + y); }
// Modular subtraction for x, y in [0, Mod).
inline int Dec(int x, int y) { return (x - y < 0) ? (x - y + Mod) : (x - y); }
// Modular multiplication; the product is formed in 64 bits.
inline int Mul(int x, int y) { return 1ll * x * y % Mod; }
// Fast exponentiation: x^y mod Mod for y >= 0.
inline int Pow(int x, int y) {
    int ret = 1;
    for (; y; y >>= 1, x = Mul(x, x))
        if (y & 1)
            ret = Mul(ret, x);
    return ret;
}
int rev[N * 4];  // bit-reversal permutation for the current length
int lim;         // transform length (a power of two)
int g[25];       // g[i]: primitive 2^i-th root of unity mod Mod
int invg[25];    // invg[i]: inverse of g[i]

// Chooses lim as the smallest power of two greater than n + m, builds the
// bit-reversal table for it, and fills the per-level roots g / invg.
void Init(int n, int m) {
    lim = 1;
    while (lim <= n + m)
        lim <<= 1;
    const int highBit = lim >> 1;
    for (int i = 0; i < lim; ++i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? highBit : 0);
    for (int lv = 1, sz = 2; sz <= lim; ++lv, sz <<= 1) {
        g[lv] = Pow(G, (Mod - 1) / sz);
        invg[lv] = Pow(g[lv], Mod - 2);
    }
}
// In-place iterative NTT of poly[0..lim). type > 0: forward transform;
// otherwise the inverse roots are used. The inverse is NOT divided by lim
// here; the caller does that.
void NTT(int *poly, int type) {
    // Reorder into bit-reversed positions so butterflies can run in place.
    for (int i = 0; i < lim; ++i) {
        const int j = rev[i];
        if (i < j)
            swap(poly[i], poly[j]);
    }
    // At level lv the sub-transform length is len == 2^lv.
    for (int lv = 1, len = 2; len <= lim; ++lv, len <<= 1) {
        const int half = len / 2;
        const int step = (type > 0) ? g[lv] : invg[lv];
        for (int start = 0; start < lim; start += len) {
            int *lo = poly + start;
            int *hi = lo + half;
            int w = 1;
            for (int k = 0; k < half; ++k) {
                const int t = Mul(w, hi[k]);
                hi[k] = Dec(lo[k], t);
                lo[k] = Add(lo[k], t);
                w = Mul(w, step);
            }
        }
    }
}
int n, m, A[N * 4], B[N * 4];
// Multiplies a degree-n and a degree-m polynomial modulo Mod and prints the
// n + m + 1 coefficients of the product.
int main() {
    n = Rd();
    m = Rd();
    for (int i = 0; i <= n; ++i) A[i] = (Rd() + Mod) % Mod;
    for (int i = 0; i <= m; ++i) B[i] = (Rd() + Mod) % Mod;
    Init(n, m);
    NTT(A, 1);
    NTT(B, 1);
    for (int i = 0; i < lim; ++i) A[i] = Mul(A[i], B[i]);
    NTT(A, -1);
    // NTT(-1) leaves a factor of lim in every coefficient; divide it out here.
    const int invLim = Pow(lim, Mod - 2);
    for (int i = 0; i <= n + m; ++i) printf("%d ", Mul(A[i], invLim));
    putchar('\n');
    return 0;
}
请推完式子再写NTT,否则容易写错!对着式子写。
FFT
注意使用 `double` 类型!
#include <cstdio>
#include <iostream>
#include <cmath>
using namespace std;
const int N = 3000005;
const double PI = std::acos(-1.0);
// Minimal complex number: x is the real part, y the imaginary part.
struct Complex {
    double x, y;
    Complex() : x(0), y(0) {}
    Complex(double re, double im) : x(re), y(im) {}
    Complex operator+(const Complex &o) const { return Complex(x + o.x, y + o.y); }
    Complex operator-(const Complex &o) const { return Complex(x - o.x, y - o.y); }
    Complex operator*(const Complex &o) const {
        return Complex(x * o.x - y * o.y, x * o.y + y * o.x);
    }
};
Complex f[N], g[N];  // input polynomials, transformed in place
int n, m, rev[N];
// In-place iterative FFT of poly[0..lim). type == 1: forward transform;
// type == -1: inverse (conjugate roots, no division by lim here).
void FFT(Complex *poly, int lim, int type) {
    for (int i = 0; i < lim; ++i)
        if (i < rev[i]) swap(poly[i], poly[rev[i]]);
    for (int len = 2; len <= lim; len <<= 1) {
        const int half = len >> 1;
        // Primitive len-th root of unity. Note the angle is 2*PI/len, not /half.
        const double ang = 2 * PI / len;
        const Complex step(cos(ang), sin(ang) * type);
        for (Complex *blk = poly; blk < poly + lim; blk += len) {
            Complex w(1, 0);
            for (int k = 0; k < half; ++k) {
                // The partner index is k + half (no -1).
                const Complex t = w * blk[k + half];
                blk[k + half] = blk[k] - t;
                blk[k] = blk[k] + t;
                w = w * step;
            }
        }
    }
}
// Multiplies two integer polynomials with a floating-point FFT and prints the
// n + m + 1 coefficients of the product.
int main() {
    scanf("%d%d", &n, &m);
    for (int i = 0; i <= n; ++i) scanf("%lf", &f[i].x);
    for (int i = 0; i <= m; ++i) scanf("%lf", &g[i].x);
    // Smallest power of two >= number of coefficients of the product.
    int lim = 1, len = n + m + 1;
    while (lim < len) lim = lim << 1;
    for (int i = 1; i < lim; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (lim >> 1) : 0);
    FFT(f, lim, 1);
    FFT(g, lim, 1);
    for (int i = 0; i < lim; ++i) f[i] = f[i] * g[i];
    FFT(f, lim, -1);
    // The inverse FFT leaves a factor of lim; divide it out and round to the
    // nearest integer. std::lround also rounds negative values correctly,
    // whereas truncating (x + 0.49) turns e.g. -2.9999 into -2.
    for (int i = 0; i <= n + m; ++i) printf("%d ", static_cast<int>(std::lround(f[i].x / lim)));
    putchar('\n');
    return 0;
}
三模NTT(任意模数NTT)
实现需要一些小技巧:封装一个类 `struct Int`,同时存储一个数在三个模数下的余数。这样只需要写一个 NTT 就行了。
#include <cstdio>
#include <iostream>
using namespace std;
typedef long long ll;
// Three NTT-friendly primes; CRT over them reconstructs any coefficient
// smaller than M1 * M2 * M3 exactly before it is reduced modulo `mod`.
const int N = (1 << 18) + 5, M1 = 998244353, M2 = 1004535809, M3 = 469762049;
int n, m, mod;
// Modular helpers parameterised by the modulus z; x and y must be in [0, z).
// ad must use >=: with >, x + y == z would yield z instead of 0, and merge()
// could then print `mod` for a coefficient that should be 0.
int ad(int x, int y, int z) { return (x + y >= z) ? (x + y - z) : (x + y); }
int dc(int x, int y, int z) { return (x - y < 0) ? (x - y + z) : (x - y); }
int ml(int x, int y, int z) { return 1ll * x * y % z; }
// x^y mod z for y >= 0.
int ksm(int x, int y, int z) {
    int ret = 1;
    for (; y; y >>= 1, x = ml(x, x, z))
        if (y & 1)
            ret = ml(ret, x, z);
    return ret;
}
// CRT constants (Fermat inverses): IV1 = M1^{-1} mod M2,
// IV2 = (M1 * M2)^{-1} mod M3, T = M1 * M2.
const int IV1 = ksm(M1, M2 - 2, M2), IV2 = ksm(ml(M1, M2, M3), M3 - 2, M3);
const ll T = 1ll * M1 * M2;
// A value stored as its residues modulo M1, M2 and M3 at once, so a single
// NTT implementation performs the three transforms in lock step.
struct Int {
    int r1, r2, r3;  // residues mod M1, M2, M3
    Int(int x, int y, int z) : r1(x), r2(y), r3(z) {}
    Int(int x = 0) : r1(x % M1), r2(x % M2), r3(x % M3) {}  // assumes x >= 0
    Int operator+(const Int &d) const { return Int(ad(r1, d.r1, M1), ad(r2, d.r2, M2), ad(r3, d.r3, M3)); }
    Int operator-(const Int &d) const { return Int(dc(r1, d.r1, M1), dc(r2, d.r2, M2), dc(r3, d.r3, M3)); }
    Int operator*(const Int &d) const { return Int(ml(r1, d.r1, M1), ml(r2, d.r2, M2), ml(r3, d.r3, M3)); }
    // Reconstructs the value from (r1, r2, r3) by CRT and returns it mod MOD.
    // Step 1 merges r1, r2 into r4 (the value mod M1 * M2); step 2 lifts r4
    // with r3, giving value = k2 * M1 * M2 + r4, which is reduced modulo MOD.
    // const: it only reads the residues.
    int merge(int MOD) const {
        int k1 = ml(dc(r2, r1, M2), IV1, M2);
        ll r4 = (1ll * k1 * M1 + r1) % T;
        int k2 = ml(dc(r3, r4 % M3, M3), IV2, M3);
        return ad(ml(k2, ml(M1, M2, MOD), MOD), r4 % MOD, MOD);
    }
} F[N], G[N], g[20], ig[22];
int rev[N];
void Init() {
for (int i=0; i<=18; ++i) g[i] = Int(ksm(3, (M1-1)>>i, M1), ksm(3, (M2-1)>>i, M2), ksm(3, (M3-1)>>i, M3));
for (int i=0; i<=18; ++i) ig[i] = Int(ksm(g[i].r1, M1-2, M1), ksm(g[i].r2, M2-2, M2), ksm(g[i].r3, M3-2, M3));
}
// Smallest power of two strictly greater than n, i.e. the NTT length needed
// to hold a polynomial of degree n.
int GetUp(int n) {
    int size;
    for (size = 1; size <= n; size *= 2)
        ;
    return size;
}
// Fills rev[1..up) with the bit-reversal permutation for length up (a power
// of two); rev[0] keeps whatever it holds (0 for the zero-initialised global).
void GetRev(int up) {
    const int topBit = up >> 1;
    for (int i = 1; i < up; ++i) {
        rev[i] = rev[i >> 1] >> 1;
        if (i & 1)
            rev[i] |= topBit;
    }
}
// In-place NTT of f[0..up) simultaneously modulo M1, M2, M3.
// type > 0: forward; otherwise inverse, including the division by up.
void NTT(Int *f, int type, int up) {
    for (int i = 0; i < up; ++i)
        if (i < rev[i])
            swap(f[i], f[rev[i]]);
    // At level lv the sub-transform length is len == 2^lv.
    for (int lv = 1, len = 2; len <= up; ++lv, len <<= 1) {
        const Int root = (type > 0) ? g[lv] : ig[lv];
        const int half = len >> 1;
        for (Int *blk = f; blk != f + up; blk += len) {
            Int w(1);
            for (int k = 0; k < half; ++k) {
                Int t = blk[k + half] * w;
                blk[k + half] = blk[k] - t;
                blk[k] = blk[k] + t;
                w = w * root;
            }
        }
    }
    if (type < 0) {
        // Scale by up^{-1} modulo each prime.
        const Int scale(ksm(up, M1 - 2, M1), ksm(up, M2 - 2, M2), ksm(up, M3 - 2, M3));
        for (int i = 0; i < up; ++i)
            f[i] = f[i] * scale;
    }
}
int main() {
Init();
scanf("%d%d%d", &n, &m, &mod);
int x, up;
for (int i=0; i<=n; ++i) {
scanf("%d", &x);
F[i] = Int(x%mod);
}
for (int i=0; i<=m; ++i) {
scanf("%d", &x);
G[i] = Int(x%mod);
}
up=GetUp(n+m);
GetRev(up);
NTT(F, 1, up);
NTT(G, 1, up);
for (int i=0; i<up; ++i) F[i]=F[i]*G[i];
NTT(F, -1, up);
for (int i=0; i<=n+m; ++i) printf("%d ", F[i].merge(mod));
return 0;
}
多项式exp
由于多项式 ln、多项式求逆都在多项式 exp 中用到了,所以这里只放多项式 exp 的代码。
注释中记录了可能犯的错误。
#include <cstdio>
#include <iostream>
using namespace std;
// NTT-friendly prime 998244353 with primitive root G = 3.
const int N = 262144 + 5, Mod = 998244353, G = 3;
// Modular addition for x, y in [0, Mod). Must be >=: with >, x + y == Mod
// would return Mod instead of 0.
int ad(int x, int y) { return (x + y >= Mod) ? (x + y - Mod) : (x + y); }
// Modular subtraction for x, y in [0, Mod).
int dc(int x, int y) { return (x - y < 0) ? (x - y + Mod) : (x - y); }
// Modular multiplication via a 64-bit product.
int ml(int x, int y) { return (long long) x * y % Mod; }
// x^y mod Mod for y >= 0.
int ksm(int x, int y) {
    int ret = 1;
    for (; y; y >>= 1, x = ml(x, x))
        if (y & 1) ret = ml(ret, x);
    return ret;
}
int rev[N];           // bit-reversal permutation for the current length
int g[25], invg[25];  // per-level roots of unity and their inverses (see Init)

// Rebuilds rev[0..up) as the bit-reversal permutation for length up.
// Easy mistake in later templates: OR-ing in up >> 1 without checking i & 1.
void GetRev(int up) {
    rev[0] = 0;
    for (int i = 1; i < up; ++i)
        rev[i] = (i & 1) ? ((rev[i >> 1] >> 1) | (up >> 1)) : (rev[i >> 1] >> 1);
}
// Returns the smallest power of two strictly greater than deg, i.e. the
// working length for a polynomial of degree deg.
int GetUp(int deg) {
    int up = 1;
    for (; up <= deg; up <<= 1) {
    }
    return up;
}
// For lv in [0, 23]: g[lv] = G^((Mod - 1) / 2^lv), the primitive 2^lv-th root
// of unity mod Mod, and invg[lv] = its modular inverse.
void Init() {
    for (int lv = 0; lv <= 23; ++lv) {
        g[lv] = ksm(G, (Mod - 1) >> lv);
        invg[lv] = ksm(g[lv], Mod - 2);
    }
}
// In-place NTT of F[0..up) using the current rev table.
// tp > 0: forward; otherwise inverse, including the division by up.
void NTT(int *F, int up, int tp) {
    for (int i = 0; i < up; ++i) {
        const int j = rev[i];
        if (i < j) swap(F[i], F[j]);
    }
    // len must double each round (easy mistake: writing ++l instead of l <<= 1).
    for (int lv = 1, len = 2; len <= up; ++lv, len <<= 1) {
        const int half = len >> 1;
        const int root = (tp > 0) ? g[lv] : invg[lv];
        for (int *blk = F; blk < F + up; blk += len) {
            for (int k = 0, w = 1; k < half; ++k, w = ml(w, root)) {
                const int t = ml(blk[k + half], w);
                blk[k + half] = dc(blk[k], t);
                blk[k] = ad(blk[k], t);
            }
        }
    }
    if (tp < 0) {
        const int invUp = ksm(up, Mod - 2);
        for (int i = 0; i < up; ++i) F[i] = ml(F[i], invUp);
    }
}
// Zeroes the half-open range F[L, R); does nothing when L >= R.
void Clear(int *F, int L, int R) {
    for (int *p = F + L; p < F + R; ++p) *p = 0;
}
// Polynomial inverse by Newton iteration: writes into G0 the first `up`
// coefficients of 1/F (mod x^up). up is expected to be a power of two (NTT
// lengths are up << 1), and the inverse only exists when F[0] != 0 mod Mod.
// G0 must have room for 2*up entries.
// fir == 1 only on the outermost call, which zeroes G0[0, 2*up); recursive
// calls pass 0 so the partially built result is kept.
// Side effect: leaves rev set for length 2*up (GetRev(up << 1)); Ln relies on that.
void Inv(const int *F, int *G0, int up, bool fir = 1) {
if (fir) Clear(G0, 0, up << 1);
if (up == 1) {
G0[0] = ksm(F[0], Mod - 2); // base case: inverse of the constant term (Fermat)
return ;
}
Inv(F, G0, up >> 1, 0); // now G0 = 1/F mod x^(up/2)
static int T[N]; // scratch shared across calls; the recursive call above finishes before it is filled
for (int i = 0; i < up; ++i) T[i] = F[i];
Clear(T, up, up << 1);
GetRev(up << 1);
NTT(G0, up << 1, 1);
NTT(T, up << 1, 1);
// Newton step, pointwise: G = 2*G0 - G0^2 * F.
for (int i = 0; i < (up << 1); ++i) G0[i] = dc(ml(2, G0[i]), ml(ml(G0[i], G0[i]), T[i]));
// Easy mistake: writing ml(ml(G0[i], G0[i]), F[i]) here; F is not in the point-value domain, T is.
NTT(G0, up << 1, -1);
Clear(G0, up, up << 1); // keep only the first up coefficients
}
// In-place integral of the polynomial in F[0..up) with zero constant term:
// F[i] becomes F[i - 1] / i (modular inverse via Fermat); the original top
// coefficient F[up - 1] is dropped.
void Int(int *F, int up) {
    int prev = F[0];
    F[0] = 0;
    for (int i = 1; i < up; ++i) {
        const int cur = F[i];
        F[i] = ml(prev, ksm(i, Mod - 2));
        prev = cur;
    }
}
// In-place derivative of the polynomial in F[0..up): F[i] becomes
// (i + 1) * F[i + 1], and the top coefficient F[up - 1] becomes 0.
void Der(int *F, int up) {
    const int deg = up - 1;
    for (int i = 1; i <= deg; ++i) F[i - 1] = ml(F[i], i);
    F[deg] = 0;
}
// Polynomial logarithm: writes ln F mod x^up into G0 using
// ln F = \int F'(x) / F(x) dx. Assumes F[0] == 1 (standard precondition for
// ln; not checked here). G0 needs room for 2*up entries.
// Note: this function never calls GetRev itself; the NTTs below rely on Inv
// having just set rev for length 2*up.
void Ln(const int *F, int *G0, int up) {
for (int i = 0; i < up; ++i) G0[i] = F[i];
Clear(G0, up, up << 1);
Der(G0, up); // G0 = F'
static int T[N];
Inv(F, T, up); // T = 1/F mod x^up
NTT(G0, up << 1, 1);
NTT(T, up << 1, 1);
for (int i = 0; i < (up << 1); ++i) G0[i] = ml(G0[i], T[i]);
NTT(G0, up << 1, -1);
Clear(G0, up, up << 1);
Int(G0, up); // integrate F'/F with zero constant term
} // Easy mistake: deriving the formula as \int 1/f(x) dx (forgetting the f'(x) factor)
// Polynomial exponential by Newton iteration: writes exp(F) mod x^up into G0.
// Assumes F[0] == 0 (not checked). up should be a power of two, and G0 needs
// room for 2*up entries. fir == 1 only on the outermost call, which zeroes
// G0[0, 2*up).
void Exp(const int *F, int *G0, int up, bool fir = 1) {
    if (fir) Clear(G0, 0, up << 1);
    if (up == 1) {
        G0[0] = 1;  // exp(0) = 1
        return;
    }
    // Recurse with fir = 0, as Inv does: the outermost call already cleared
    // the whole buffer, so clearing again at every level is redundant work.
    Exp(F, G0, up >> 1, 0);  // now G0 = exp(F) mod x^(up/2)
    static int T[N];
    Ln(G0, T, up);  // through Inv, this also leaves rev set for length 2*up
    // Newton step: G = G0 * (1 - ln G0 + F).
    for (int i = 0; i < up; ++i) T[i] = dc(F[i], T[i]);
    T[0] = ad(T[0], 1);
    NTT(G0, up << 1, 1);
    NTT(T, up << 1, 1);
    for (int i = 0; i < (up << 1); ++i) G0[i] = ml(G0[i], T[i]);
    NTT(G0, up << 1, -1);
    Clear(G0, up, up << 1);
}
int n, A[N], B[N];
// Reads n coefficients of F and prints the first n coefficients of exp(F).
int main() {
    Init();
    scanf("%d", &n);
    --n;  // from here on, n is the degree (the input gives the number of terms)
    for (int i = 0; i <= n; ++i) scanf("%d", A + i);
    const int up = GetUp(n);
    Exp(A, B, up);
    for (int i = 0; i <= n; ++i) printf("%d ", B[i]);
    putchar('\n');
    return 0;
}
计算几何
二维凸包
注意:
- 在计算下凸壳的时候需要用最开始的那个点再弹出一些点。否则凸包上会多一个点。
- 必须两个轴都排序:以 \(x\) 坐标为第一关键字、\(y\) 坐标为第二关键字排序。
#include <cstdio>
#include <algorithm>
#include <cmath>
using namespace std;
const int N = 1e5 + 5;
const double eps = 1e-10;
// Approximate equality of doubles within eps. std::fabs is used explicitly so
// an int-only ::abs can never be selected.
bool eql(double x, double y) { return std::fabs(x - y) < eps; }
// 2D point / vector.
struct Dot {
    double x, y;
    Dot(double _x = 0, double _y = 0) : x(_x), y(_y) {}
    Dot operator-(const Dot &d) const { return Dot(x - d.x, y - d.y); }
    // Cross product: z-component of (*this) x d.
    double operator^(const Dot &d) const { return x * d.y - d.x * y; }
    // Lexicographic order, x first and then y. Sorting by x alone leaves points
    // with equal x in arbitrary order, which breaks the monotone-chain hull.
    // The comparison is exact on purpose: an eps-based equality is not
    // transitive and would violate std::sort's strict weak ordering.
    bool operator<(const Dot &d) const { return (x != d.x) ? (x < d.x) : (y < d.y); }
} dt[N];
// Euclidean distance between two points.
double dis(Dot x, Dot y) { return std::sqrt((x.x - y.x) * (x.x - y.x) + (x.y - y.y) * (x.y - y.y)); }
int n, stk[N], tp;
bool vis[N];
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; ++i) scanf("%lf%lf", &dt[i].x, &dt[i].y);
sort(dt + 1, dt + n + 1);
//先计算上凸壳
for (int i = 1; i <= n; ++i) {
while (tp >= 2 && ((dt[stk[tp - 1]] - dt[stk[tp]]) ^ (dt[stk[tp]] - dt[i])) > -eps)
vis[stk[tp--]] = 0;
stk[++tp] = i;
vis[i] = 1;
}
vis[1] = 0;
//在围下凸壳的时候需要用 1号点再弹出一些点
int nb = tp;
for (int i = n; i >= 1; --i) {
if (vis[i])
continue ;
while (tp > nb && ((dt[stk[tp - 1]] - dt[stk[tp]]) ^ (dt[stk[tp]] - dt[i])) > -eps)
--tp;
stk[++tp] = i;
}
double ans = 0;
for (int i = 1; i <= tp; ++i)
ans += dis(dt[stk[i]], dt[stk[i % tp + 1]]);
printf("%.2lf\n", ans);
return 0;
}
旋转卡壳
注意:
- 凸包上避免三点共线
- 注意特判只有两个点的情况
- 卡壳的时候必须是\(\le\) (注释中有)因为可能有平行的边。
#include <cstdio>
#include <algorithm>
#include <cmath>
using namespace std;
const int N = 5e4 + 5;
// Integer 2D point / vector.
struct dot {
    int x, y;
    dot(int _x = 0, int _y = 0) : x(_x), y(_y) {}
    dot operator-(const dot &d) const { return dot(x - d.x, y - d.y); }
    // Cross product: z-component of (*this) x d.
    int operator*(const dot &d) const { return x * d.y - y * d.x; }
    // Lexicographic order: by x, then by y.
    bool operator<(const dot &d) const {
        if (x != d.x) return x < d.x;
        return y < d.y;
    }
} dt[N];
// Squared Euclidean distance (kept integral, no sqrt needed).
int dis(dot a, dot b) {
    const int dx = a.x - b.x, dy = a.y - b.y;
    return dx * dx + dy * dy;
}
int n, stk[N];
bool use[N];
// Prints the squared diameter (farthest pair distance) of the point set,
// using a convex hull plus rotating calipers.
int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i)
        scanf("%d%d", &dt[i].x, &dt[i].y);
    sort(dt + 1, dt + n + 1);
    int tp = 0;
    // Upper hull, left to right; use[] marks the points it keeps.
    for (int i = 1; i <= n; ++i) {
        while (tp >= 2 && (dt[stk[tp - 1]] - dt[stk[tp]]) * (dt[stk[tp]] - dt[i]) >= 0) {
            use[stk[tp]] = 0;
            --tp;
        }
        use[i] = 1;
        stk[++tp] = i;
    }
    // Point 1 is reused by the lower pass to close the hull.
    use[1] = 0;
    int rec = tp;
    for (int i = n; i >= 1; --i) {
        if (use[i]) continue;
        // >= 0 (not > 0) pops collinear points, so the hull never contains
        // three collinear points.
        while (tp > rec && (dt[stk[tp - 1]] - dt[stk[tp]]) * (dt[stk[tp]] - dt[i]) >= 0)
            --tp;
        stk[++tp] = i;
    }
    // stk[1..tp] is the hull with stk[tp] == stk[1]. tp == 3 means all points
    // are collinear (a segment); tp <= 2 means a single distinct point. In
    // both cases every cross product in the caliper loop is 0, so its "<="
    // condition never fails and the loop would spin forever. Answer directly.
    if (tp <= 3) {
        printf("%d\n", dis(dt[stk[1]], dt[stk[2]]));
        return 0;
    }
    int ans = 0, pt = 1;
    for (int i = 1; i < tp; ++i) {
        int x = stk[i], y = stk[i + 1];
        // Advance the antipodal pointer while the cross product (twice the
        // triangle area against edge x-y) does not decrease. It must be <=
        // because the hull may have parallel edges.
        while ((dt[y] - dt[stk[pt]]) * (dt[x] - dt[stk[pt]]) <= (dt[y] - dt[stk[pt % tp + 1]]) * (dt[x] - dt[stk[pt % tp + 1]]))
            pt = pt % (tp - 1) + 1;
        ans = max(ans, max(dis(dt[x], dt[stk[pt]]), dis(dt[y], dt[stk[pt]])));
    }
    printf("%d\n", ans);
    return 0;
}