在任意代数结构上的多项式乘法 学习笔记
前言
Stop learning useless algorithms, go and solve some problems, learn how to use binary search.
以下内容大多是作者看完《如何在任意代数结构上做多项式乘法》[1] 后口胡的,所以可能和原文章不太一样。如果错了或者有更好的做法请告诉我。
分圆多项式
定义为 \(\Phi_n(x) = \prod_{1 \le k \lt n,\gcd(k, n)=1}(x - \omega_n^k)\).
也可以感性理解为 \(x=\omega_n\) 时 \(x^n-1=0\),约掉一些“显然”不为 \(0\) 的因式后剩下的素多项式。
分圆多项式都是整系数素多项式,且 \(\Phi_n(x)\) 最高次数为 \(\varphi(n)\)。
结论:多项式 \(f(x)\) 代入 \(x=\omega_n\) 后做运算得到的结果(用 \(\omega_n\) 表示,最后再把 \(\omega_n\) 换成 \(x\))等于先做运算再 \(\bmod \Phi_n(x)\) 得到的结果。感性理解就是 \(\Phi_n(\omega_n)=0\)。
当 \(n=p^m\)(\(p\) 为素数,\(m \ge 1\))时,\(\Phi_n(x)=\sum\limits_{i=0}^{\varphi(n)}x^i\)。
算法原理
要求:三个群 \((A,+_A),(B,+_B),(C,+_C)\),乘法运算 \(\cdot:A\times B \rightarrow C\) 具有分配律 \((a_1 +_A a_2) \cdot (b_1 +_B b_2) = a_1 \cdot b_1 +_C a_1 \cdot b_2 +_C a_2 \cdot b_1 +_C a_2 \cdot b_2\)。
此时必有 \(\forall b \in B, e_A \times b = e_c\) 和 \(\forall a \in A,a \times e_b = e_c\),其中 \(e_A,e_B,e_C\) 分别为 \(A,B,C\) 中的单位元。于是将 \(A\),\(B\) 的高位填对应的单位元即可。
证明:
\(a \cdot b = (a+_Ae_A) \cdot b = a\cdot b +_C e_A \cdot b\),两边加上 \(a \cdot b\) 在 \(C\) 中的逆元即可。
\(e_B\) 是类似的。
\(C\) 最好还能有较快(\(O(1)\))的自然数乘,定义为多个相同的元素加在一起。
Part 1. 解决除法
IDFT 最后要除以长度,而 \(C\) 中没有定义自然数乘的逆。
一个解决方法是,分别做长为 \(2\) 的幂的 DFT 和长为 \(3\) 的幂的 DFT,这样每个元素的 \(2^{c_2}\) 倍和 \(3^{c_3}\) 倍都已知(\(c_2\) 和 \(c_3\) 取决于长度),类似辗转相除做即可。
Part 2. 解决单位根
这是一个很神仙的做法。
考虑把一部分 \(x\) 代入 \(\omega_m\) 满足 \(\varphi(m) \gt \deg A(x)+\deg B(x)\),然后将 \(m\) 拆成 \(m=pq\)。取 \(p=q=\sqrt m\) 能保证最优复杂度,读者自(wo)证(bu)不(hui)难(zheng)。具体实现可以参考代码。
具体地,
然后将内层带 \(\omega_p\) 的东西看成系数对外层做 DFT。实现时可以做成指针套数组的形式。这个部分可能不太好理解,可以看代码。
做完 DFT 要进行内层元素相乘,可以递归。
最后对分圆多项式取模即可。实现时可以暴力将高位减到低位。
应用
好像没啥用...
目前想到的就是做 \(c_k = \prod_{i+j=k}a_i^{b_j}\) 之类的卷积?
实现
给出一份大常数的实现。期待有大佬能优化。
题目是 lgP3803.
#define DEBUG 0
#include <iostream>
#include <algorithm>
#include <cmath>
#define UP(i,s,e) for(auto i=s; i<e; ++i)
#define DOWN(i,e,s) for(auto i=e; i-->s;)
using std::cin; using std::cout;
namespace Poly{ // }{{{
template<int BASE, typename T>
void change(T* arr, int len){
int *rev = new int[len];
rev[0] = 0;
UP(i, 1, len){
rev[i] = rev[i/BASE]/BASE;
rev[i] += i%BASE*(len/BASE);
}
UP(i, 0, len) if(rev[i] > i) std::swap(arr[i], arr[rev[i]]);
delete[] rev;
}
template<int BASE, class A>
void fft(A **a, int len, int siz, bool idft){ // siz == len(a[0])
static A *tmp[BASE];
UP(i, 0, BASE){
tmp[i] = new A[siz];
//UP(j, siz, siz*BASE){
// tmp[i][j].unit();
//}
}
change<BASE>(a, len);
int wn = siz/BASE;
for(int h=BASE; h<=len; h*=BASE){
for(int st=0; st<len; st+=h){
int w=0;
UP(i, st, st+h/BASE){
UP(j, 0, BASE) std::swap(a[i+h/BASE*j], tmp[j]);
UP(j, 0, BASE){
auto &now = a[i+h/BASE*j];
std::copy(tmp[0], tmp[0]+siz, now);
UP(k, 1, BASE){
UP(l, 0, siz){
int idx = l-(idft?-1:1)*(w+siz/BASE*j)*k;
idx %= siz; idx = idx < 0 ? idx + siz : idx;
now[l] += tmp[k][idx];
}
}
}
w += wn;
}
}
wn /= BASE;
}
UP(i, 0, BASE) delete[] tmp[i];
//delete[] tmp;
}
// mod Phi_len(x)
// len = BASE**n
template<int BASE, class A, class B, class C>
int polymul_base(A *a, B *b, C *ret, int len
#if DEBUG
, int test=0
#endif
){
UP(i, 0, len/BASE*(BASE-1)) ret[i].unit();
if(len < 100
#if DEBUG
&& !test
#endif
){
int phi_len = len / BASE * (BASE-1);
UP(i, 0, len) UP(j, 0, len){
if((i+j)%len >= phi_len) UP(k, 1, BASE){
ret[(i+j)%len-len/BASE*k] += (a[i]*b[j]).inv();
} else {
ret[(i+j)%len] += a[i]*b[j];
}
}
return 1;
}
int tim = std::round(std::log(len)/std::log(BASE));
int p = std::round(std::pow(BASE, tim/2+1));
int q = std::round(std::pow(BASE, (tim-1)/2));
A **aa = new A*[BASE*q];
B **bb = new B*[BASE*q];
C **cc = new C*[BASE*q];
UP(i, 0, BASE*q){
aa[i] = new A[p];
bb[i] = new B[p];
cc[i] = new C[p];
}
UP(i, 0, q*BASE) UP(j, 0, p){
aa[i][j].unit(); bb[i][j].unit();// cc[i][j].unit();
}
UP(i, 0, q*BASE) UP(j, p/BASE*(BASE-1), p){
cc[i][j].unit();
}
UP(i, 0, q){
UP(j, 0, p){
if(j*q+i >= len){
break;
//aa[i][j].unit(); bb[i][j].unit();
}
else {
aa[i][j] = a[j*q+i];
bb[i][j] = b[j*q+i];
}
}
//UP(j, p/BASE*(BASE-1), p){ aa[i][j].unit(); bb[i][j].unit(); }
}
//UP(i, q, BASE*q){
//UP(j, 0, p){ aa[i][j].unit(); bb[i][j].unit(); }
//}
fft<BASE>(aa, BASE*q, p, false);
fft<BASE>(bb, BASE*q, p, false);
int scale;
UP(i, 0, BASE*q){
scale = polymul_base<BASE>(aa[i], bb[i], cc[i], p
#if DEBUG
, test ? test-1 : 0
#endif
);
}
UP(i, 0, BASE*q){
delete[] aa[i];
delete[] bb[i];
}
delete[] aa;
delete[] bb;
fft<BASE>(cc, BASE*q, p, true);
int pq = p*q;
int phi_pq = pq/BASE*(BASE-1);
UP(i, 0, BASE*q) UP(j, 0, p){
int pl = (i+j*q)%pq;
if(pl >= phi_pq) UP(k, 1, BASE) ret[(pl-pq/BASE*k)%len] += cc[i][j].inv();
else ret[pl%len] += cc[i][j];
}
UP(i, 0, BASE*q) delete[] cc[i];
delete[] cc;
return scale * BASE * q;
}
template<class A, class B, class C>
void polymul(A *a, B *b, C *ret, int len){
bool swapped = false;
C *tmp = new C[len*2];
int l2 = std::round(std::pow(2, std::ceil(std::log(len*2) / std::log(2))));
int l3 = std::round(std::pow(3, std::ceil(std::log(len*3/2) / std::log(3))));
int tim2 = polymul_base<2>(a, b, tmp, l2);
int tim3 = polymul_base<3>(a, b, ret, l3);
while(tim3 != 1){
if(tim2 > tim3){
int scale = tim2 / tim3;
UP(i, 0, len) tmp[i] += ret[i].inv() * scale;
tim2 %= tim3;
}
std::swap(tim2, tim3);
std::swap(ret, tmp);
swapped ^= 1;
}
if(swapped){ std::swap(ret, tmp); std::copy(tmp, tmp+len, ret); }
delete[] tmp;
}
} // {}}}
namespace m{ // }{{{
constexpr int N = 5e6+2;
struct u32{
unsigned val;
u32(){}
u32(unsigned v):val(v){}
void unit(){val = 0;}
u32 inv(){ return -val; }
u32 &operator+=(u32 b){ val += b.val; return *this; }
u32 &operator*=(u32 b){ val *= b.val; return *this; }
u32 operator*(u32 b){ return b *= *this;}
u32 operator*(unsigned x){ return val*x; }
} ia[N], ib[N], ic[N];
int in, im;
void work(){
cin >> in >> im;
UP(i, 0, in+1){
cin >> ia[i].val;
}
UP(i, 0, im+1){
cin >> ib[i].val;
}
Poly::polymul(ia, ib, ic, in+im+1);
UP(i, 0, in+im+1){
cout << ic[i].val << ' ';
}
}
} // {}}}
int main(){cin.tie(0)->sync_with_stdio(0); m::work(); return 0;}
update:自己造了道题,欢迎爆踩标算(
附代码(仅给出不同部分):
namespace m{ // }{{{
constexpr int N = 5e6+2;
using u32 = unsigned;
using ll = long long;
u32 qpow(u32 x, u32 t){
u32 ans = 1;
while(t){
if(t&1){ ans = ans * x ; }
x = x * x;
t >>= 1;
}
return ans;
}
struct Modbase{
u32 val;
Modbase(){}
Modbase(u32 x):val(x){}
Modbase &operator+=(Modbase x){ val = x.val*val; return *this; }
Modbase operator*(int x){ return qpow(val, x); }
Modbase inv(){ return qpow(val, (~0u)>>1); }
void unit(){ val = 1;}
} ia[N], ic[N];
struct Modexp{
u32 val;
Modexp &operator+=(Modexp x){ val += x.val; return *this; }
void unit(){ val = 0; }
} ib[N];
Modbase operator*(Modbase x, Modexp y){
return qpow(x.val, y.val&((~0u)>>1));
}
int in, im;
void work(){
cin >> in >> im;
UP(i, 0, in){
cin >> ia[i].val;
}
UP(i, 0, im){
cin >> ib[i].val;
}
int len = in+im-1;
int l2 = std::round(std::pow(2, std::ceil(std::log(len*2) / std::log(2))));
int l3 = std::round(std::pow(3, std::ceil(std::log(len*3/2) / std::log(3))));
len = std::max(l2, l3);
UP(i, in, len) ia[i].unit();
Poly::polymul(ia, ib, ic, in+im-1);
UP(i, 0, in+im-1){
cout << ic[i].val << ' ';
}
cout << '\n';
}
} // {}}}