转置原理学习笔记
本文参考 wangrx 浅谈转置原理 和 Vocalise 的博客。
1.矩阵的初等变换
也是高斯消元的基础。
1.1 定义
对矩阵施以下三种变换,称为矩阵的初等变换 :
- 交换矩阵的两行(列)
- 以一个非零数 \(k\) 乘矩阵的某一行(列)
- 把矩阵的某一行(列)的 \(l\) 倍加于另一行(列)
对单位矩阵 \(I\) 施以一次初等变换得到的矩阵,称为初等矩阵。
1.2 一些定理
设 \(A_{m\times n}=(a_{ij})_{m\times n}\)
-
定理 1 :对 \(A\) 的行施以一次初等变换,等同于用同种 \(m\) 阶初等矩阵左乘 \(A\)。对 \(A\) 的列施以一次初等变换,等同于用同种 \(n\) 阶初等矩阵右乘 \(A\)。
Proof:将 \(A\) 和 \(I\) 矩阵分块即可。
容易验证,初等矩阵都是可逆的,且它们的逆仍是初等矩阵。
\(I(ij)^{-1}=I(ij),I(i(k))^{-1}=I(i(\dfrac{1}{k})),I(ij(l))^{-1}=I(ij(-l))\)
-
定理 2 :\(n\) 阶矩阵 \(A\) 可逆的充要条件是它可以写成一些初等矩阵的乘积。
因此,\(A\) 可以表示成若干个初等矩阵的乘积。
2.算法的转置
2.1 定义
将一个算法看做方阵 \(A\),输入向量为 \(\vec v\),输出向量为 \(A \vec v\),则称该算法为线性算法。
许多算法都是线性算法,例如 FFT 的单位根矩阵。
2.2 转置的定义和性质
当输入为 \(\vec v\),输出为 \(A^T\vec v\) 的算法称为该算法的转置算法。
转置的几个性质:
-
\((AB)^T=B^TA^T\)
-
对于初等矩阵的转置: 前两种的转置是它本身。最后一种如果是将第 \(i\) 行(列)的 \(k\) 倍加到第 \(j\) 行,
则其转置为将第 \(j\) 行(列)的 \(k\) 倍加到第 \(i\) 行(列)。
2.3 线性算法的判定
但算法的实际流程并未与矩阵建立起对应的方式。我们需要找到一种途径来刻画这个过程。
将输入,输出变量和运行过程中的一切辅助变量(也就是整个内存)拼成一个向量。
也就是已知 \(\begin{bmatrix} \vec v \\ \vec 0 \\ \vec 0 \end{bmatrix}\), 求 \(\begin{bmatrix} \vec 0 \\ \vec 0 \\ A \vec v\end{bmatrix} = \begin{bmatrix} O \ \ \ O \ \ \ O \\ O \ \ \ O \ \ \ O \\ A \ \ \ O \ \ \ O\end{bmatrix}\begin{bmatrix}\vec v \\ \vec 0 \\ \vec 0\end{bmatrix}\).
若 \(A\) 不是方阵,也可以通过补 \(0\) 补成方阵。
称向量中的量为变量,矩阵中的量为常量。
具体的,线性算法只包含一下运算:
- 交换两个变量。
- \(x \leftarrow m \times x\),\(x\) 为变量, \(m\) 为常量。
- \(x \leftarrow x+ m \times y,\)\(x, y\) 为变量,\(m\) 为常量。
2.4 原算法与转置算法的转换
将矩阵分解为初等矩阵,设 \(A = E_1E_2 \cdots E_{n}\),则有 \(A^T = E_n^TE_{n-1}^T \cdots E_1^T\)。
由转置的性质可知,将算法转置相当于将运算顺序反过来,再讲 \(x \leftarrow x +my\),改写成 \(y \leftarrow y + mx\) 即可。
3.常见算法的转置
推荐阅读:127: 浅谈转置原理
3.1 前缀和
前缀和问题 :\(b_i = \sum\limits_{j \leq i}a_i\),转置为后缀和问题 \(a_i=\sum\limits_{i \leq j}b_j\) 。
for(int i = 1; i <= n; ++i) a[i] += a[i - 1];
首先将执行顺序倒序,变成 :
for(int i = n; i >= 1; --i) \\ do something..
然后将 \(a_i \leftarrow a_i +a_{i - 1}\) 转置为 \(a_{i-1} \leftarrow a_{i-1}+a_i\)。
for(int i = n; i >= 1; --i) a[i - 1] += a[i];
3.2 DFT
\(n\) 阶 DFT 形如 :\(F_{i,j}=(\omega_n^i)^j\) ,其转置为本身,即 \(F^T = F\),\(F_{i,j}^{-1}=2^{-n}(\omega_n^j)^i\)。
原算法的核心步骤是 : \(\begin{bmatrix} 1 \ \ \ w_j \\ 1 \ -w_j\end{bmatrix}\begin{bmatrix}a_{j+k} \\ a_{j+k+i} \end{bmatrix}\),转置为 : \(\begin{bmatrix} 1 \ \ \ \ 1 \\ w_j \ -w_j\end{bmatrix}\begin{bmatrix} a_{j+k} \\ a_{j+k+i} \end{bmatrix}\)
inline void DFT(int *a, int n) {
for(int i = n >> 1; i; i >>= 1) {
int x = qpow(3, (mod - 1) / (i << 1), mod);
for(int j = 0; j < n; j += (i << 1)) {
int y = 1;
for(int k = 0; k < i; ++k) {
int p = a[j + k], q = a[j + k + i];
a[j + k] = (p + q) % mod;
a[j + k + i] = 1ll * (p - q + mod) * y % mod;
}
}
}
for(int i = 0; i < n; ++i) if(i < rev[i]) std :: swap(a[i], a[rev[i]]);
}
考虑在 IDFT 时使用原算法,这样我们就免去了两次位逆序的过程。
3.3 多项式乘法
原算法:已知 \(n+1\) 维向量 \(\vec a\),\(m+1\) 维向量 \(\vec b\),\(\vec a\) 是变量, \(\vec b\) 是常量。求 \(n+m+1\) 维向量 \(\vec c\),使得:
\(c_k = \sum\limits_{i}a_ib_{k-i}\)
即 \(c_k \leftarrow c_k+a_ib_{k-i}\),转置为 : \(a_i \leftarrow a_i+c_kb_{k-i}\)。
即已知 \(n+m+1\) 维向量 \(\vec c\), \(m\) 维向量 \(\vec b\),求 \(\vec c \times \vec b^R\) 的 第 \(m\) 到 \(m+n\) 个元素。
是一个差卷积的形式,将 \(b\) 翻转后卷积即可。
由转置的过程本身可知, 原算法与转置算法的复杂度相同。
若得到计算 \(A \vec v\) 的优化方法,则必然可以通过如上机械地改写得到 \(A^T \vec v\) 的优化算法。
这便是转置原理,又称特勒根原理。
3.4 多项式多点求值
问题:给定 \(n\) 次多项式 \(f(x)\),对于每个 \(i \in [1,m]\),求出 \(f(a_i)\)。
即 \(f(a_i) = \sum\limits_{i=0}^{n-1}f_ia_i^i\)。
将 \(f_i\) 看作输入向量,则 \(C\) 即为 :
将矩阵补成方阵,转置得到:
即 $g_i = \sum\limits_{j=0}{n}f_ja_ji $,写成生成函数可以得到 : \(G(x) = \sum\limits_{i=0}^{n}\dfrac{f_i}{1 - a_ix}\)
这个 \(G(x)\) 可以通过分治记录下分子 \(f_{L,R}\) 和分母 \(g_{L,R}\) 得到:
\(g_{L,R}=g_{L,mid} \times g_{mid+1,R}\)
\(f_{L,R}=f_{L,mid} \times g_{mid+1,R} +f_{mid+1,R} \times g_{L,mid}\)
转化为线性算法的标准形式:
-
自下而上初始化常量 \(x_i\)。\(\vec g_{L,L} = \begin{bmatrix}1 \\ -x_i \end{bmatrix}\),\(\vec g_{L,R} = \vec g_{L,mid} \times g_{mid+1,R}\)。
-
初始化向量为 \(\begin{bmatrix} \vec f \\ \vec 0\end{bmatrix}\)。
-
赋值 \(\vec f_{L,L} \leftarrow f_L\)。
-
自下而上分治 : \(\vec f_{L,mid} \leftarrow \vec f_{L,mid} \times \vec g_{mid+1,R}\),\(\vec f_{mid+1,R} \leftarrow \vec f_{mid+1,R} \times \vec g_{L,mid}\)。
\(\vec f_{L,R} \leftarrow \vec f_{L,R}+\vec f_{L,mid}\),\(\vec f_{L,R} \leftarrow \vec f_{L,R}+\vec f_{mid+1,R}\)。
-
最后计算 \(\vec f \leftarrow \vec f_{1,n} \times g_{1,n}^{-1}\)
-
将辅助变量清零。
转置算法:
-
初始化常量 \(\vec g_{L,R}\)。
-
初始化向量为 \(\begin{bmatrix} \vec f \\ \vec 0\end{bmatrix}\)。
-
将辅助变量清零。
-
计算 \(\vec f_{1,n} \leftarrow \vec f \times ^T g_{1,n}^{-1}\),\(\times ^T\) 为多项式乘法的转置。
-
自上而下分治:
\(\vec f_{mid+1,R} \leftarrow \vec f_{mid+1,R} + \vec f_{L,R}\)
\(\vec f_{L,mid} \leftarrow \vec f_{L,mid}+\vec f_{L,R}\)
\(\vec f_{mid+1,R} \leftarrow f_{mid+1,R} \times \vec g_{L,mid}\)
\(\vec f_{L,mid} \leftarrow \vec f_{L,mid} \times \vec g_{mid+1,R}\)
-
赋值 \(f_i \leftarrow \vec f_{L,L}\)。
P5050 【模板】多项式多点求值 代码如下:
#include <bits/stdc++.h>
const int M = 3e5 + 5;
const int INF = 0x3f3f3f3f;
const int mod = 998244353, G = 3, InvG = (mod + 1) / G;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef std :: vector < int > Poly;
int n, m;
int rev[M];
inline int qpow(int a, int b, int p) {
int s = 1;
for(int bas = a; b; b >>= 1, bas = 1ll * bas * bas % p)
if(b & 1) s = 1ll * s * bas % p;
return s;
}
inline int add(int x, int y) {return x + y >= mod ? x + y - mod : x + y;}
inline int dec(int x, int y) {return x < y ? x - y + mod : x - y;}
inline void DFT(Poly &a) {
const int n = a.size();
for(int i = 0; i < n; ++i) rev[i] = (rev[i >> 1] >> 1) | (i & 1 ? n >> 1 : 0);
for(int i = 0; i < n; ++i) if(i < rev[i]) std :: swap(a[i], a[rev[i]]);
for(int i = 1; i < n; i <<= 1) {
int x = qpow(G, (mod - 1) / (i << 1), mod);
for(int j = 0; j < n; j += (i << 1)) {
int y = 1;
for(int k = 0; k < i; ++k, y = 1ll * y * x % mod) {
int p = a[j + k], q = 1ll * y * a[j + k + i] % mod;
a[j + k] = add(p, q), a[j + k + i] = dec(p, q);
}
}
}
}
inline void IDFT(Poly &a) {
const int n = a.size();
for(int i = 0; i < n; ++i) rev[i] = (rev[i >> 1] >> 1) | (i & 1 ? n >> 1 : 0);
for(int i = 0; i < n; ++i) if(i < rev[i]) std :: swap(a[i], a[rev[i]]);
for(int i = 1; i < n; i <<= 1) {
int x = qpow(InvG, (mod - 1) / (i << 1), mod);
for(int j = 0; j < n; j += (i << 1)) {
int y = 1;
for(int k = 0; k < i; ++k, y = 1ll * y * x % mod) {
int p = a[j + k], q = 1ll * y * a[j + k + i] % mod;
a[j + k] = add(p, q), a[j + k + i] = dec(p, q);
}
}
}
int inv = qpow(n, mod - 2, mod);
for(int i = 0; i < n; ++i) a[i] = 1ll * a[i] * inv % mod;
}
inline Poly Mul(Poly a, Poly b) {
const int n = a.size() + b.size() - 1; int lim;
for(lim = 1; lim < n; lim <<= 1);
a.resize(lim), b.resize(lim), DFT(a), DFT(b);
for(int i = 0; i < lim; ++i) a[i] = 1ll * a[i] * b[i] % mod;
IDFT(a); a.resize(n);
return a;
}
inline Poly MulT(Poly a, Poly b) {
const int n = a.size(), m = b.size();
std :: reverse(b.begin(), b.end());
b = Mul(a, b);
for(int i = 0; i < n; ++i) a[i] = b[i + m - 1];
return a;
}
Poly tmp;
inline Poly Inv(Poly a, int n) {
if(n == 1) return Poly(1, qpow(a[0], mod - 2, mod));
Poly b = Inv(a, n + 1 >> 1);
int lim; for(lim = 1; lim <= 2 * n; lim <<= 1);
tmp.resize(lim), b.resize(lim); for(int i = 0; i < n; ++i) tmp[i] = a[i];
for(int i = n; i < lim; ++i) tmp[i] = 0;
DFT(tmp), DFT(b);
for(int i = 0; i < lim; ++i) b[i] = 1ll * b[i] * dec(2, 1ll * b[i] * tmp[i] % mod) % mod;
IDFT(b), b.resize(n); return b;
}
Poly P[M], f, a, v;
inline void build_G(int p, int l, int r, Poly &a) {
if(l == r) {
P[p].push_back(1);
if(a[l]) P[p].push_back(mod - a[l]);
else P[p].push_back(0);
return ;
}
int mid = l + r >> 1;
build_G(p << 1, l, mid, a), build_G(p << 1 | 1, mid + 1, r, a);
P[p] = Mul(P[p << 1], P[p << 1 | 1]);
}
inline void build_F(int p, int l, int r, Poly F, Poly &v) {
F.resize(r - l + 1);
if(l == r) return v[l] = F[0], void();
int mid = l + r >> 1;
build_F(p << 1, l, mid, MulT(F, P[p << 1 | 1]), v);
build_F(p << 1 | 1, mid + 1, r, MulT(F, P[p << 1]), v);
}
inline void MultiPoint(Poly f, Poly a, Poly &v, int n) {
f.resize(n + 1), a.resize(n);
build_G(1, 0, n - 1, a), v.resize(n), build_F(1, 0, n - 1, MulT(f, Inv(P[1], n + 1)), v);
}
inline int read() {
int f = 1, s = 0; char ch = getchar();
while(!isdigit(ch)) (ch == '-') && (f = -1), ch = getchar();
while(isdigit(ch)) s = (s << 1) + (s << 3) + (ch ^ 48), ch = getchar();
return f * s;
}
int main()
{
n = read() + 1, m = read();
for(int i = 0, x; i < n; ++i) x = read(), f.push_back(x);
for(int i = 0, x; i < m; ++i) x = read(), a.push_back(x);
MultiPoint(f, a, v, std :: max(n, m));
for(int i = 0; i < m; ++i) printf("%d\n", v[i]);
return 0;
}