转置原理
去年尝试理解过,今年再看才学懂了一点。基本抄的 EI & qwaszx 的课件。
简介
转置原理给出的是,通过 \(\mathbf{b}=A\mathbf{a}\) 的算法来解决 \(\mathbf{\hat b}=A^{T}\mathbf{\hat a}\),这里 \(\mathbf{a}\) 和 \(\mathbf{\hat a}\) 是输入给定的向量,而 \(A\) 是一个常量矩阵。
转置原理要求得到的 \(\mathbf{b}\) 中每个元素都是 \(\mathbf{a}\) 中元素的线性组合,换而言之,\(\mathbf{b}\) 中的元素(也包括算法过程中)不允许出现 \(\mathbf{a}\) 的非一次项(也就是出现 \(a_ia_j,{a_i}^{2}\) 这样的),这代表了这个算法是一个线性算法。
线性算法在分解中 \(A\) 完成计算,即计算 \(A_1\cdots A_m\mathbf{a}\),其中 \(A\) 均为初等矩阵(或者是比较简单的矩阵?)。那么其转置问题即为计算 \({A_m}^T\cdots {A_1}^T\mathbf{a}\),这称之为原算法的转置算法。也就是将原算法的操作逆序执行它们的转置,即得到了转置后的算法。
例子
例如在前缀和算法中,输入给定了需要作前缀和的 \(\mathbf{a}\),将其左乘一个矩阵得到前缀和后的结果 \(A\mathbf{a}=\mathbf{b}\),不难写出 \(A\) 是主对角线及以下均为 \(1\),其余位置都为 \(0\) 的矩阵,而将其转置后 \(A^T\) 是主对角线及以上均为 \(1\),其余位置均为 \(0\) 的矩阵,不难验证 \(A^T\mathbf{a}\) 即为对 \(\mathbf{a}\) 作后缀和,这意味着前缀和算法转置后得到了后缀和算法。
试着写写 \(b_i\gets b_i+cb_{j}\) 的转置:
转置后:
所以 \(b_i\gets b_i+cb_{j}\) 转置后得到了 \(b_{j}\gets b_j+cb_i\).
常见操作的转置:
\(a_i\gets a_i+ ca_j\) | \(a_j\gets a_j + ca_i\) |
---|---|
\(swap(a_i,a_j)\) | \(swap(a_i,a_j)\) |
\(a_i\gets a_j\) | \(a_j\gets a_i+a_j,a_i=0\) |
\(a_i\gets ca_i\) | \(a_i\gets ca_i\) |
FFT
我们都知道 FFT 是将 \(n\) 次单位根的若干次幂代入多项式得到点值。不难写出其对应的矩阵 \(A\) 为 \(\omega ^{ij}\),那么就有 \(A=A^T\),这意味着将 FFT 的转置算法和 FFT 的效果一样。转置后 bitrev 依然是 bitrev,但是注意到 bitrev 会在迭代后再进行,而 IDFT 开头也有个 bitrev,那么两个 bitrev 互相抵消,可以直接省略这个 bitrev,从而减少常数。
我将自己本来就跑的不是很快的 NTT 转置后,在洛谷上测试 \(10^6\) 的多项式乘法,每个点快了 150~ 200 ms.
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<vector>
#include<bitset>
#define pb emplace_back
#define mp std::make_pair
#define fi first
#define se second
using namespace std;
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
typedef std::pair<int, int> pii;
typedef std::vector<int> vi;
const ll mod = 998244353;
ll Add(ll x, ll y) { return (x+y>=mod) ? (x+y-mod) : (x+y); }
ll Mul(ll x, ll y) { return x * y % mod; }
ll Mod(ll x) { return x < 0 ? (x + mod) : (x >= mod ? (x-mod) : x); }
ll cadd(ll &x, ll y) { return x = (x+y>=mod) ? (x+y-mod) : (x+y); }
ll cmul(ll &x, ll y) { return x = x * y % mod; }
template <typename T> T Max(T x, T y) { return x > y ? x : y; }
template<typename T, typename... T2> T Max(T x, T2 ...y) { return Max(x, y...); }
template <typename T> T Min(T x, T y) { return x < y ? x : y; }
template<typename T, typename... T2> T Min(T x, T2 ...y) { return Min(x, y...); }
template <typename T> T cmax(T &x, T y) { return x = x > y ? x : y; }
template <typename T> T cmin(T &x, T y) { return x = x < y ? x : y; }
template <typename T>
T &read(T &r) {
r = 0; bool w = 0; char ch = getchar();
while(ch < '0' || ch > '9') w = ch == '-' ? 1 : 0, ch = getchar();
while(ch >= '0' && ch <= '9') r = r * 10 + (ch ^ 48), ch = getchar();
return r = w ? -r : r;
}
template<typename T1, typename... T2>
void read(T1 &x, T2& ...y) { read(x); read(y...); }
ll qpow(ll x, ll y) {
ll s = 1;
while(y) {
if(y & 1) s = s * x % mod;
x = x * x % mod;
y >>= 1;
}
return s;
}
const int N = 4000010;
ll *getw(int n, int type) {
static ll w[N/2];
w[0] = 1; w[1] = qpow(type == 1 ? 3 : 332748118, (mod-1)/n);
for(int i = 2; i < n/2; ++i) w[i] = w[i-1] * w[1] % mod;
return w;
}
void DFT(ll *a, int n) { //转置
for(int i = n/2; i; i >>= 1) {
ll *w = getw(i << 1, 1);
for(int j = 0; j < n; j += i << 1) {
ll *b = a + j, *c = b + i;
for(int k = 0; k < i; ++k) {
ll u = b[k], v = c[k];
b[k] = (u + v) % mod;
c[k] = Add((u * w[k]) % mod, Mod(-v * w[k] % mod));
}
}
}
}
void IDFT(ll *a, int n) {
for(int i = 1; i < n; i <<= 1) {
ll *w = getw(i << 1, -1);
for(int j = 0; j < n; j += i << 1) {
ll *b = a + j, *c = b + i;
for(int k = 0; k < i; ++k) {
ll v = c[k] * w[k] % mod;
c[k] = Add(b[k], Mod(-v));
cadd(b[k], v);
}
}
}
ll inv = qpow(n, mod-2);
for(int i = 0; i < n; ++i) a[i] = a[i] * inv % mod;
}
int n, m, len = 1, ct;
ll f[N], g[N];
signed main() {
read(n); read(m);
for(int i = 0; i <= n; ++i) read(f[i]);
for(int i = 0; i <= m; ++i) read(g[i]);
while(len <= n+m) len <<= 1, ++ct;
DFT(f, len);
DFT(g, len);
for(int i = 0; i < len; ++i) f[i] = f[i] * g[i] % mod;
IDFT(f, len);
for(int i = 0; i <= n+m; ++i) printf("%lld ", f[i]);
return 0;
}
Do Use FFT, GYM102978D
给定长为 \(N\) 的序列 \(A, B, C\),对 \(k = 1,\cdots, N\) 求出
\(N\leq 2.5\times 10^5\),模 \(998244353\).
首先搞清楚谁是“输入”,这里只有 \(C\) 作为变量时是仅有一次项,所以将 \(C\) 看作输入,假设得到的答案为 \(q_1,q_2,\cdots ,q_N\),那么变换矩阵即为 \(M_{i,j}=\prod _{k\leq i}(A_j+B_k)\),将其转置得到:
注意到当 \(i\) 改变的时候仅有 \(A_i\) 这里会改变,所以将其看作一个元 \(x\),则有:
欲求 \(F\),分治 FFT,分治到 \([l,r]\) 的时候维护 \(\sum_{l\leq j\leq r}\left(q_j\prod_{l\leq k\leq j}(x+B_k) \right)\) 和 \(\prod_{l\leq k\leq r}(x+B_k)\) 即可,然后再对 \(F\) 多点求值即可得到答案。
考虑完转置问题的解决,来考虑原问题,那么将转置问题的算法转置过来即可。时间复杂度 \(\mathcal{O}(n\log^2 n)\)
因为不会多点求值,就不实现了。
多点求值
待学罢/ll/ll/dk