拉格朗日插值与多项式乘法

进军多项式。

1. 拉格朗日插值

1.1. 普通插值

首先给出公式:

\[F(x)=\sum_{k=1}^n\left(y_k\prod_{i=1,i\neq k}^n \dfrac{x-x_i}{x_k-x_i}\right) \]

解释:对于每对点值 \((x_k,y_k)\),我们需要构造出一个函数 \(G(x)\),使得其在 \(x=x_k\) 处的取值为 \(y_k\),其余处取值为 \(0\)

首先构造函数 \(D(x)=\prod_{i=1,i\neq k}^n x-x_i\)。显然当 \(i\neq k\) 时,有 \(D(x_i)=0\)。但是现在我们不能保证 \(D(x_k)=y_k\)。为了使 \(D(x_k)=y_k\),我们只需要先将其除以 \(D(x_k)\),再乘以 \(y_k\) 即可,这就有了上面的拉格朗日插值公式

通常情况下,题目会要求我们求出 \(F(x)\) 在给定某个 \(x\) 处的取值,此时我们不把 \(x\) 看做函数的一个元,而是直接将 \(x\) 带入上式即可。时间复杂度为 \(\mathcal{O}(n^2)\),代码见例题 I。

1.2. 连续取值插值

很多情况下,我们求出的点值 \(x_i\) 满足 \(x_i=i\),即 \(x_i\) 是连续的。此时我们重新写一下公式:

\[\sum_{k=1}^n\left(y_k\prod_{i=1,i\neq k}^n \dfrac{x-i}{k-i}\right) \]

\(p_i=\prod_{j=1}^ix-i\)\(s_i=\prod_{j=i+1}^n x-i\),这些可以线性预处理,那么上述柿子右边就变成了

\[\dfrac{p_{k-1}s_{k+1}}{(k-1)!\times (-1)^{n-k}(n-k)!} \]

预处理阶乘就可以线性插值,应用见例题 II。

1.3. 求 \(F(x)\) 各项系数

构造函数 \(D(x)=\prod_{i=1}^n(x-x_i)\),设 \(d_k=y_i\prod_{i=1,i\neq k}^n\dfrac{1}{x_k-x_i}\),则 \(F(x)=\sum_{k=1}^n \dfrac{d_kD(x)}{x-x_k}\)。注意到 \(D(x)\) 的各项系数可以在 \(n^2\) 的时间内暴力处理出来,而对于每个 \(k\),我们可以线性将 \(D(x)\) 除以一个一次多项式。最后加和即可。时间复杂度 \(\mathcal{O}(n^2)\)

1.4. 例题

I. P4781 【模板】拉格朗日插值

板子题。

#include <bits/stdc++.h>
using namespace std;

#define ll long long

const ll mod=998244353;
const int N=2e3+5;

ll ksm(ll a,ll b){
	ll s=1;
	while(b){
		if(b&1)s=s*a%mod;
		a=a*a%mod,b>>=1;
	} return s;
}

int n,k,x[N],y[N],ans;
int main(){
	cin>>n>>k;
	for(int i=1;i<=n;i++)cin>>x[i]>>y[i];
	for(ll i=1,s1=1,s2=1;i<=n;i++,s1=s2=1){
		for(int j=1;j<=n;j++)if(i!=j)s1=s1*(k-x[j])%mod,s2=s2*(x[i]-x[j])%mod;
		ans=(ans+y[i]*s1%mod*ksm(s2,mod-2))%mod;
	} cout<<(ans%mod+mod)%mod<<endl;
	return 0;
}

II. CF622F The Sum of the k-th Powers

经典题。一个结论是自然数 \(k\) 次方和是 \(k+1\) 次多项式,那么只需要带 \((i,i^k)\ (i\in [0,k+1])\) 插值即可。注意到当 \(i=0\) 时对答案无贡献,所以在插值的时候可以跳过(但是预处理 \(p,s\) 的时候仍应考虑 \(i=0\)\(n-i\) 的这个 \(n\))。

时间复杂度 \(\mathcal{O}(k\log k)\),使用线性筛筛 \(i^k\) 可以做到线性。

#include <bits/stdc++.h>
using namespace std;

#define ll long long

const ll mod=1e9+7;
const int N=1e6+5;

ll ksm(ll a,ll b){
	ll s=1;
	while(b){
		if(b&1)s=s*a%mod;
		a=a*a%mod,b>>=1;
	} return s;
} ll inv(ll x){return ksm(x,mod-2);}

ll n,k,ans;
ll p[N],s[N],fc[N];
int main(){
	cin>>n>>k,s[k+2]=1;
	for(int i=0;i<=k+1;i++)fc[i]=i?fc[i-1]*i%mod:1,p[i]=i?p[i-1]*(n-i)%mod:n;
	for(int i=k+1;~i;i--)s[i]=i==k+1?n-i:s[i+1]*(n-i)%mod;
	for(ll i=1,res=0;i<=k+1;i++){
		res=(res+ksm(i,k))%mod;
		ans=(ans+p[i-1]*s[i+1]%mod*res%mod*ksm(fc[i]*((k-i)&1?1:-1)*fc[k+1-i]%mod,mod-2))%mod;
	} cout<<(ans%mod+mod)%mod<<endl;
	return 0;
}

III. P4463 [集训队互测 2012] calc

不妨设 \(a_i<a_{i+1}\ (1\leq i<n)\),那么只需将答案乘 \(n!\) 即可。

首先把转移方程写出来:\(f_{i,j}\) 表示值域落在 \([1,i]\) 且长度为 \(j\) 时的答案(即 \(k=i\)\(n=j\))。显然有 \(f_{i,j}=f_{i-1,j}+[j>0]f_{i-1,j-1}\times i\)

  • 接下来我们分析一下 \(f_{i,j}\) 的次数:

    首先我们有 \(f_{i,0}\) 是关于 \(i\)\(0\) 次多项式。接下来使用一个小 Trick。

    Trick:差分分析次数。

    根据转移方程 \(f_{i,j}=f_{i-1,j}+f_{i-1,j-1}\times i\),有 \(f_{i,j}-f_{i-1,j}=f_{i-1,j-1}\times i\)。由于将两个相差 \(1\) 的数带入一个 \(c\) 次多项式得到的差是 \(c-1\) 次多项式,我们有 \(f_{i,j}\) 的次数 \(-1\) 等于 \(f_{i,j-1}\) 的次数 \(+1\),因此 \(f_{i,j}\) 是关于 \(i\)\(2j\) 次多项式

    所以拉格朗日插值即可。

const int N=1.5e3+5;

ll f[N][N],k,n,p,ans;
ll ksm(ll a,ll b){
	ll s=1;
	while(b){
		if(b&1)s=s*a%p;
		a=a*a%p,b>>=1;
	}
	return s;
}
int main(){
	cin>>k>>n>>p,f[0][0]=1;
	for(ll i=1;i<=n*3+4;i++)
		for(ll j=0;j<=min(i,n);j++)
			f[i][j]=(f[i-1][j]+(j?f[i-1][j-1]*i:0))%p;
	for(ll i=n;i<=n*3+3;i++){
		ll s1=f[i][n],s2=1;
		for(ll j=n;j<=n*3+3;j++)if(i!=j)
			s1=s1*(k+p-j)%p,
			s2=s2*(i+p-j)%p;
		ans=(ans+s1*ksm(s2,p-2))%p;
	}
	for(ll i=1;i<=n;i++)ans=ans*i%p;
	cout<<ans<<endl;
	return 0;
}

*IV. [BZOJ2137]submultiple

题意简述:对于 \(m\) 的所有约数 \(d\),求 \(\sum \mathrm{d}^k(d)\)\(m\) 由唯一分解给出。

对于 \(45\%\) 的数据,\(k\leq 2^{31}-1\)\(p\leq 10^5\);对于剩下 \(55\%\) 的数据,\(k\leq 12\)\(p\leq 2^{63}-1\)

\(n\leq 10^5\)\(n\) 表示 \(m\) 由前 \(n\) 小的质数组成。

一个显然的结论是我们只关心唯一分解后每个质因子的次数 \(p_i\),而底数是什么根本不重要:约数个数仅与质因子次数有关。考虑枚举 \(d\) 每个质因子的次数 \(q_i\in [0,p_i]\),得到如下柿子:

\[\sum\prod_{q_1=0}^{p_1}\prod_{q_2=0}^{p_2}\cdots\prod_{q_c=0}^{p_n}\prod_{j=1}^{n}(q_j+1)^k \]

根据加法对乘法的分配律,稍作化简:

\[\prod\left(\sum_{q_1=1}^{p_1+1}q_1^k\right)\left(\sum_{q_2=1}^{p_2+1}q_2^k\right)\cdots \]

\(S_k(n)\) 表示 \(\sum_{i=1}^n i^k\),则题目就是在求 \(\prod_{i=1}^nS_k(p_i+1)\)。一个经典结论是自然数的 \(k\) 次方前缀和是 \(k+1\) 次多项式,因此对于 \(k\) 很大,\(p\) 很小的部分直接 \(\mathcal{O}(p\log k)\) 预处理所有 \(S_k(i)\),而 \(k\) 很小,\(p\) 很大的部分直接单次 \(\mathcal{O}(k^2)\)\(\mathcal{O}(k)\)(连续取值)插值即可。

综上,时间复杂度为 \(\mathcal{O}(\min(n+p\log k,nk^2))\)

template <class T> void cmin(T &a, T b){a = a < b ? a : b;}
template <class T> void cmax(T &a, T b){a = a > b ? a : b;}
bool Mbe;

const int N = 1e5 + 5;
const ll mod = 1e9 + 7;

ll ksm(ll a, ll b = mod - 2) {
	ll s = 1;
	while(b) {
		if(b & 1) s = s * a % mod;
		a = a * a % mod, b >>= 1;
	} return s;
}
ll n, k, type = 1, pw[N];

void solve1() {
	static ll f[N], ans = 1; f[0] = 0;
	for(int i = 1; i < N; i++)
		f[i] = (f[i - 1] + ksm(i, k)) % mod;
	for(int i = 1; i <= n; i++)
		ans = ans * f[pw[i] + 1] % mod;
	cout << ans << endl;
}

void solve2() {
	static ll f[N], ans = 1; f[0] = 0;
	for(int i = 1; i <= k + 1; i++) f[i] = (f[i - 1] + ksm(i, k)) % mod;
	for(int i = 1; i <= n; i++) pw[i] %= mod;
	for(int i = 1; i <= n; i++) {
		ll val = 0;
		for(int j = 0; j <= k + 1; j++) {
			ll nume = f[j], deno = 1;
			for(int p = 0; p <= k + 1; p++) if(p != j)
				deno = deno * (mod + j - p) % mod,
				nume = nume * (pw[i] + 1 + mod - p) % mod;
			val = (val + nume * ksm(deno)) % mod;
		}
		ans = ans * val % mod;
	} cout << ans << endl;
} 

bool Med;
int main(){
	fprintf(stderr, "%.2lf\n", (&Mbe - &Med) / 1048576.0);
	cin >> n >> k;
	for(int i = 1; i <= n; i++) pw[i] = read(), type &= pw[i] <= 1e5;
	if(type) solve1();
	else solve2();
	return 0;
}

*V. [BZOJ3453]tyvj 1858 XLkxc

题意简述:设 \(f_k(n)\) 表示 \(\sum_{i=1}^ni^k\)\(g_k(n)\) 表示 \(\sum_{i=1}^n f_k(i)\)。给定 \(k,a,n,d\),求 \(\sum_{i=0}^{n}g_{k}(a+id)\bmod 1234567891\ (p)\)

\(k\leq 123\)\(a,n,d\leq 123456789\)

感觉数据范围像是出题人随手打的(大雾。

经典结论:\(n^k\) 是关于 \(n\)\(k\) 次多项式,其前缀和(积分)\(f_k(n)\) 是关于 \(n\)\(k+1\) 次多项式,\(f_{k}(n)\) 的积分 \(g_{k}(n)\) 是关于 \(n\)\(k+2\) 次多项式,次数随着前缀和次数是线性增长的,所以再套几层娃也没关系。

考虑只算一个 \(g_k(a)\) 怎么做:直接拉格朗日插值即可,把插值公式写出来:

\[\begin{aligned}&\sum_{p=0}^{n}g_{k}(a+pd)\\=&\sum_{p=0}^n\sum_{i=1}^{k+2}\left(f_{k}(i)\prod_{j\in [1,k+2]\land i\neq j}\dfrac{a+pd-j}{i-j}\right)\\=&\sum_{i=1}^{k+2}f_k(i)\sum_{p=0}^{n}\prod_{j\in [1,k+2]\land i\neq j}\dfrac{a+pd-j}{i-j}&(交换求和符号)\end{aligned} \]

如果我们把后面那个 \(\prod\) 看做一个关于 \(p\)\(k+1\) 次多项式 \(h_i(p)\),那么柿子可以写作:

\[\sum_{i=1}^{k+2}f_k(i)\sum_{p=0}^{n}h_i(p) \]

不慌,对 \(h_i\) 做一遍前缀和,记作 \(H_i\),这是一个关于 \(n\)\(k+2\) 次多项式,可以继续插值插出来(取 \(0\sim k+2\) 处的点值)。即求:

\[\begin{aligned}&\sum_{i=1}^{k+2}f_k(i)H_i(n)\\=&\sum_{i=1}^{k+2}f_k(i)\sum_{j=0}^{k+2}\left(h_i(j)\prod_{j'\in [0,k+2]\land j\neq j'}\dfrac{n-j'}{j-j'}\right)\end{aligned} \]

其中 \(h_i(j)\) 是一个插值形式的柿子。这叫 插 值 套 差 值。注意到两次插值都是连续取值插值,因此可以做到 \(\mathcal{O}(k)\) 插值,总时间复杂度 \(\mathcal{O}(k^2)\),比 tzc 直接求系数 + 二项式展开不知道简单到哪里去了(大雾。

代码偷懒使用了 \(k^2\) 插值。

template <class T> void cmin(T &a, T b){a = a < b ? a : b;}
template <class T> void cmax(T &a, T b){a = a > b ? a : b;}
bool Mbe;

const int N = 1e3 + 5;
const ll mod = 1234567891;

ll ksm(ll a, ll b = mod - 2) {
	ll s = 1;
	while(b) {
		if(b & 1) s = s * a % mod;
		a = a * a % mod, b >>= 1;
	} return s;
}

ll k, a, n, d, c, y[N];
void solve() {
	ll ans = 0;
	cin >> k >> a >> n >> d, c = k + 3;
	for(int i = 1; i <= c; i++) y[i] = (y[i - 1] + ksm(i, k)) % mod;
	for(int i = 1; i <= c; i++) y[i] = (y[i - 1] + y[i]) % mod;
	for(int i = 1; i <= c; i++) {
		ll coef = 0;
		static ll y2[N];
		for(int j = 0; j <= c + 2; j++) {
			y2[j] = j ? y2[j - 1] : 0;
			ll deno = 1, nume = 1;
			for(int p = 1; p <= c; p++) if(i != p)
				nume = nume * ((a + j * d) % mod + mod - p) % mod,
				deno = deno * (i + mod - p) % mod;
			y2[j] = (y2[j] + nume * ksm(deno)) % mod;
		} // 第一遍插值插出 h_i
		for(int j = 1; j <= c + 2; j++) {
			ll deno = 1, nume = y2[j];
			for(int p = 1; p <= c + 2; p++) if(j != p)
				nume = nume * (n + mod - p) % mod,
				deno = deno * (j + mod - p) % mod;
			coef = (coef + nume * ksm(deno)) % mod;
		} // 第二遍插值插出 H_i
		ans = (ans + coef * y[i]) % mod;
	} cout << ans << endl;
}

bool Med;
signed main(){
	fprintf(stderr, "%.2lf\n", (&Mbe - &Med) / 1048576.0);
	int T; cin >> T;
	while(T--) solve();
	return 0;
}

2. 多项式乘法:加法卷积

2.1. FFT

对于一个 \(n\) 次多项式 \(F(x)=a_0+a_1x+a_2x^2+\cdots+a_nx^n\),我们可以用 \(n+1\) 个点值 \((x_k,y_k)\) 唯一确定该多项式。即 \(y_k=\sum_{i=0}^na_ix_k^i\)(注意下文的 \(n\) 表示不小于多项式次数 \(+1\) 的最小的 \(2\) 的幂)。设 \(A=F\times G\),注意到 \(A(x)=F(x)G(x)\),其中 \(x\) 是一个确定的值。因此,我们只需要将一个多项式快速(\(n\log n\))转点值,再快速转成系数表示,就可以做到时间复杂度 \(n\log n\) 的多项式乘法。

点值的取法很有讲究,高明的方法能够极大化地减小时间复杂度。这里我们采用 \(n\) 次单位根 \(\omega_n\),并主要利用以下性质:

  • \(\omega_n^k=\omega_{2n}^{2k}\)
  • \(\omega_n^k=-\omega_n^{k+n/2}\)
  • \(\omega_n^n=1\)
  • \(\omega_n=\cos(\dfrac{2\pi}{n})+i\sin(\dfrac{2\pi}{n})\),从而计算 \(n\) 次单位根。

\(n=4\) 时,\(F(x)=a_0+a_1x+a_2x^2+a_3x^3=(a_0+a_2x^2)+x(a_1+a_3x^2)\)。设 \(L(x)=a_0+a_2x\)\(R(x)=a_1+a_3x\),那么 \(F(x)=L(x^2)+xR(x^2)\)。将单位根 \(\omega_n\) 带入,那么当 \(k<\dfrac{n}{2}\) 时,\(F(\omega_n^k)=L(\omega_{n/2}^k)+\omega_{n}^kR(\omega_{n/2}^k)\)\(F(\omega_n^{k+n/2})=L(\omega_{n/2}^k)-\omega_n^kR(\omega_{n/2}^k)\)。注意到两式只有一个正负号不同。因此,如果我们已经知道了 \(L,R\)\(\omega_{n/2}^i,\ i\in[0,\dfrac{n}{2})\) 处的取值,那么我们就可以在线性时间内求出 \(F\)\(\omega_n^i, \ i\in[0,n)\) 处的取值。考虑递归树的每一层都是线性的,因此总复杂度为 \(n\log n\)

递归处理很慢,于是我们使用迭代:考虑系数 \(a_i\)\(L\) 分治为 \(0\),向 \(R\) 分治为 \(1\),那么显然 \(a_i\) 最终形成的 “分治序列” 形成的二进制数倒过来就是 \(i\) 的二进制表示。考虑求出 \(r_i\) 表示 \(i\) 二进制翻转后得到的数。假设 \(r_0,r_1,\cdots,r_{i-1}\) 都已经求出 ,那么有 \(r_i=\lfloor\dfrac{r_{\lfloor\dfrac{i}{2}\rfloor}}{2}\rfloor+\dfrac{n}{2}\times(i\bmod 2)\)。左边是 \(i\) 不考虑最低位(即假设其为 \(0\))时二进制翻转得到的数,然后再考虑 \(i\) 的最低位的影响即可。然后对于每一对无序对 \((i,r_i)\),将 \(a_i\)\(a_{r_i}\) 交换,那么最终的分治树的形态类似线段树,直接从最底层向上迭代即可。该操作被称为蝴蝶迭代


卡常技巧

  • 结构体不要写构造函数。
  • 重载加减乘运算符放到 struct 里面。

2.2. NTT

由于复数运算很慢,而通常情况下我们是在模意义下进行多项式运算,所以当模数取一些特殊值时,我们可以用 \(\mathbb{Z}\) 中的数 \(g\) 代替单位根的复数运算。具体地,\(g\) 需要是模 \(p\) 的原根。

鸽着。

2.3. FFT 优化字符串匹配

Trick 1:翻转一个序列常常可以使关于它的某些计算变成卷积的形式。

对于一个文本串 \(s\) 与匹配串 \(t\)(下标从 \(0\) 开始),设它们的长度分别为 \(n\)\(m\)。称它们在位置 \(p\) 匹配,当且仅当对于任意 \(i\in[0,m)\),有 \(s_{p-m+i+1}=t_i\)。不难发现它的充分条件为 \(\sum_{i=0}^{m-1}(s_{p-m+i+1}-t_i)^2=0\)。展开,得到 \(\sum_{i=0}^{m-1}(s^2_{p-m+i+1}+t^2_{i}-2s_{p-m+i+1}t_i)=0\)。注意到前面两项容易预处理得到,但后面一项同时关于两个字符串,比较麻烦。又发现它的形式类似卷积,但又不是卷积:翻转字符串 \(t\) 即可。因此柿子变为 \(\sum_{i=0}^{m-1}(s^2_{p-m+i+1}+t_{m-i-1}^2-2s_{p-m+i+1}t_{m-i-1})\),后面一项即 \(2\sum_{0\leq i<m,i+j=p}s_jt_i\),FFT 计算即可。

对于有通配符的字符串,我们不妨将该位置上的值设为 \(0\),然后乘到上面的柿子里去,即 \(\sum_{i=0}^{m-1}(s_{p-m+i+1}-t_i)^2s_{p-m+i+1}t_i\)。化简得到 \(\sum_{0\leq i<m,i+j=p}(s^3_jt_i-2s^2_jt^2_i+s_jt^3_i)\),做 6 次 DFT + 1 次 IDFT 即可。

Trick 2:一般情况下,上述方法足够应付多数题目。但是如果万恶的出题人卡了 FFT 精度就凉凉了。因此,为了保险,应尽量使用 NTT。但是 NTT 也有一个致命问题:如果计算出来的值刚好是 \(998244353\) 的倍数,那么就会在不匹配的地方判定匹配。看起来好像没有解决办法了?非也。在 P4173 残缺的字符串 这题的讨论区 https://www.luogu.com.cn/discuss/show/303076,我找到了一个高明的手段解决这个问题:注意到上述柿子最大值可达到 \(3\times m\times (|\mathbb{\Sigma}|-1)^4\),如果 \(m\)\(5\times 10^5\),字符集取大小 \(26\),那么数量级为 \(1.5\times 10^6\times 25^4\approx 6\times 10^{11}\),显然不太行。但是实际上我们没必要将整个 \(s\)\(t\) 乘进去,因为我们关心的只是某一位是否是通配符,而具体这一位是什么并不重要。因此,我们令 \(S_i=[s_i\neq\texttt{*}]\)\(T_i=[t_i\neq\texttt{*}]\),只需要将 \(S,T\) 而非 \(s,t\) 乘入即可。 这时最大值仅为 \(3\times 5\times 10^5\times 25^2=9.375\times 10^8<998244353\),有惊无险地保证了正确性。当然,不同题目还应根据 \(m\) 和字符集大小的不同取值具体分析正确性。

也许看了上述分析的你认为:既然将 \(S,T\) 乘进去,那么不如直接将是通配符的位置当做 \(0\) 来计算 \(\sum_{i=0}^{m-1}(s_{p-m+i+1}-t_i)^2\) 不就好了?非也。因为这样在计算只关于某一个字符串的项时,考虑不到另外一个字符串的对应位置是否是通配符。因此,6 次 DFT 是逃不掉的。最终的柿子即为 \(\sum_{0\leq i<m,i+j=p}((s_j^2S_j)\times T_i-2(s_jS_j)\times (t_iT_i)+S_j\times (t^2_iT_i))\)

例题 & 代码可以看 II.

2.4. 任意模数:MTT

MTT 又称为任意模数 FFT。

2.4.1. 拆系数:7 次 FFT

对于值域为 \(V\) 的两个 \(n\) 次多项式,它们相乘后值域为 \(V^2n\)。对于一般的题目,\(V\approx 10^9\)\(n\approx 10^5\),所以 \(V^2n\approx 10^{23}\),long double 都无法承受。

对此,我们可以将系数拆分成 \(xp+y\) 的形式,其中 \(p\) 一般取平方大于值域的最小的 \(2\) 的幂,这里可以使用 \(2^{15}\)。这样,我们设 \(a=a_0p+a_1\)\(b=b_0p+b_1\),那么 \(c=a\times b=(a_0p+a_1)(b_0p+b_1)=a_0b_0p^2+(a_0b_1+a_1b_0)p+a_1b_1\)。值域约为 \((2^{15})^2\times n\approx 10^{14}\),可以接受。

综上,我们有了一个显然的 4 DFT + 3 IDFT = 7 FFT 的做法,但是不够快。

2.4.2. 优化 Trick:合并 DFT 两个实系数多项式

构造多项式 \(p=a+bi\)\(q=a-bi\),由于 \(a,b\) 的系数都是实数,故 \(p,q\) 对应项系数互为共轭。为了让它们对应位置点值表示的结果也是共轭的,由于积的共轭等于共轭的积,我们需要保证带入的单位根也互相共轭。例如,设 \(p\)\(j\) 项前的系数为 \(c+di\),则 \(q\) 的对应项系数为 \(c-di\)。为了让 \((c+di)\omega_p^j\)\((c-di)\omega_q^j\) 互为共轭,则 \(\omega_p\)\(\omega_q\) 也需互为共轭。

如果我们先对 \(p\) 进行 DFT,也就是算出 \(p(\omega_n^i)\) 的点值,那么根据上面的分析,它\(q(\overline{\omega_n^i})\) 互为共轭。由于 \(\omega_{n}^i\)\(i=0\) 时的共轭为它本身,当 \(i\neq 0\) 时的共轭为 \(\omega_n^{n-i}\),并且 \(p\) 在位置 \(i\) 处求得的点值为 \(p(\omega_n^i)\),因此为了得到 \(q\) 在位置 \(i\) 处的点值 \(q(\omega_n^i)\),我们只需要将 \(p\)\(1\sim n-1\) 项的点值进行序列翻转,再求共轭即可。\(a=\dfrac{p+q}{2}\)\(b=\dfrac{p-q}{2i}=\dfrac{(q-p)i}{2}\)

同时 DFT 两个实系数多项式,可以压缩成一次 DFT。大大减小了常数。

2.4.3. 4 次 FFT

根据 2.4.2. 的 Trick,原来的 4 DFT 就被优化成了 2 DFT。四维变二维。

有了 \(a_0,a_1,b_0,b_1\) 的点值,接下来可以计算 \(a_0b_0,a_0b_1+a_1b_0,a_1b_1\) 的点值。可是 IDFT 回去仍要做三遍 FFT,并不是很优秀。

我们借鉴上述思想,构造多项式 \(p=a_0b_0+a_0b_1i\)。将其进行 IDFT 之后,由于 \(a_0b_0\)\(a_0b_1\) 的系数表示法的每一项系数都为实数,因此 \(p\) 的实部就是 \(a_0b_0\) 的系数表示,\(p\) 的虚部就是 \(a_0b_1\) 的系数表示。一个直观的解释就是:给出 \(3+4i\),你知道是它是由一个实数和一个纯虚数相加得到,那么很容易就能得出原来的两个数分别是 \(3\)\(4i\)。相反,如果你知道它是由一个复数和另一个复数相加得到,那么显然无法还原原来的两个数是什么。而我们所做的就是前者。

对于 \(a_1b_0\)\(a_1b_1\),同理。

综上,我们只需要进行 2 DFT + 2 IDFT = 4 FFT,常数非常优秀。代码见例题 III.

2.5. 任意模数:三模数 NTT

考虑对 \(f(x)\) 在可 NTT 模数下进行 NTT,然后使用中国剩余定理 CRT 合并即可。这个还是比较好理解的,就不细说了。

一般选模数分别为 \(469762049,998244353,1004535809\),这样可以处理到值域为大约 \(4\times 10^8\times 10^9\times 10^9\approx 4\times 10^{26}\) 的多项式乘法,对付一般的题目应该是绰绰有余了。

不过常数稍大,可以手写结构体将对于这三个模数分别取模后的值打包,这样常数会小一些。

2.6. 任意长度:Bluestein 算法

2.7. 例题

I. P3803 【模板】多项式乘法(FFT)

FFT:

#include <bits/stdc++.h>
using namespace std;

#define ld double

const int N=1<<21;
const ld Pi=acos(-1);

struct com{
	ld x,y;
	com operator + (com b){return (com){x+b.x,y+b.y};}
	com operator - (com b){return (com){x-b.x,y-b.y};}
	com operator * (com b){return (com){x*b.x-y*b.y,x*b.y+y*b.x};}
}a[N],b[N];

int n,m,lim=1,bit,r[N];
void FFT(com *a,int tp){
	for(int i=0;i<lim;i++)if(i<r[i])swap(a[i],a[r[i]]);
	for(int l=1;l<lim;l<<=1){
		com wn={cos(Pi/l),tp*sin(Pi/l)};
		for(int j=0;j<lim;j+=(l<<1)){
			com w={1,0},x,y;
			for(int k=0;k<l;k++,w=w*wn)
				x=a[j+k],y=w*a[j+k+l],a[j+k]=x+y,a[j+k+l]=x-y;
		}
	}
}
int main(){
	cin>>n>>m;
	for(int i=0;i<=n;i++)scanf("%lf",&a[i].x);
	for(int i=0;i<=m;i++)scanf("%lf",&b[i].x);
	while(lim<=n+m)lim<<=1,bit++;
	for(int i=1;i<lim;i++)r[i]=(r[i>>1]>>1)|((i&1)<<bit-1);
	FFT(a,1),FFT(b,1);
	for(int i=0;i<lim;i++)a[i]=a[i]*b[i];
	FFT(a,-1);
	for(int i=0;i<=n+m;i++)cout<<(int)(a[i].x/lim+0.5)<<" ";
	return 0;
}

NTT:

#include <bits/stdc++.h>
using namespace std;

#define ll long long
#define ull unsigned long long
#define gc getchar()

inline int read(){
	int x=0; char s=gc;
	while(!isdigit(s))s=gc;
	while(isdigit(s))x=s-'0',s=gc;
	return x;
}

const ll mod=998244353;
const int N=1<<21;

ll ksm(ll a,ll b){
	ll s=1;
	while(b){
		if(b&1)s=s*a%mod;
		a=a*a%mod,b>>=1;
	} return s;
} ll inv(ll x){return ksm(x,mod-2);}

const int G=3;
const int ivG=inv(3);

ll tr[N],lim=1,l;
ll n,m,f[N],g[N];
void NTT(ll *a,bool tp){
	static ull f[N],w[N]; w[0]=1;
	for(int i=0;i<lim;i++)f[i]=a[tr[i]];
	for(int l=1;l<lim;l<<=1){
		ll wn=ksm(tp?G:ivG,(mod-1)/(l+l));
		for(int i=1;i<l;i++)w[i]=w[i-1]*wn%mod;
		for(int i=0;i<lim;i+=l<<1){
			for(int j=0;j<l;j++){
				int y=w[j]*f[i|j|l]%mod;
				f[i|j|l]=f[i|j]+mod-y,f[i|j]+=y;
			}
		} if(l==(1<<17))for(int i=0;i<lim;i++)f[i]%=mod;
	}
	if(!tp){
		ll iv=inv(lim);
		for(int i=0;i<lim;i++)a[i]=f[i]%mod*iv%mod;
	} else for(int i=0;i<lim;i++)a[i]=f[i]%mod;
}
int main(){
	cin>>n>>m;
	for(int i=0;i<=n;i++)f[i]=read();
	for(int i=0;i<=m;i++)g[i]=read();
	while(lim<=n+m)lim<<=1,l++;
	for(int i=1;i<lim;i++)tr[i]=(tr[i>>1]>>1)|((i&1)<<l-1);
	NTT(f,1),NTT(g,1);
	for(int i=0;i<lim;i++)f[i]=f[i]*g[i]%mod;
	NTT(f,0);
	for(int i=0;i<=n+m;i++)printf("%lld ",f[i]);
}

II. P4173 残缺的字符串

经典字符串匹配题。

#include <bits/stdc++.h>
using namespace std;

typedef double db;
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;

#define gc getchar()
#define pb push_back
#define mem(x,v,n) memset(x,v,sizeof(int)*n)
#define cpy(x,y,n) memcpy(x,y,sizeof(int)*n)

const ld Pi=acos(-1);
const ll mod=998244353;

inline int read(){
	int x=0; char s=gc;
	while(!isdigit(s))s=gc;
	while(isdigit(s))x=x*10+s-'0',s=gc;
	return x;
}

ll ksm(ll a,ll b){
	ll s=1;
	while(b){
		if(b&1)s=s*a%mod;
		a=a*a%mod,b>>=1;
	}
	return s;
}
ll inv(ll x){return ksm(x,mod-2);}

const int N=1<<19;
const ll G=3;
const ll ivG=inv(3);

int r[N],pren;
void pre(int n){
	if(n==pren)return;
	for(int i=1;i<n;i++)r[i]=(r[i>>1]>>1)|(i&1?n>>1:0);
}
void NTT(int *g,int n,bool op){
	pre(n);
	static ull f[N],w[N]; w[0]=1;
	for(int i=0;i<n;i++)f[i]=g[r[i]];
	for(int l=1;l<n;l<<=1){
		ull wn=ksm(op?G:ivG,(mod-1)/(l+l));
		for(int i=1;i<l;i++)w[i]=w[i-1]*wn%mod;
		for(int i=0;i<n;i+=l<<1)
			for(int j=0;j<l;j++){
				int t=w[j]*f[i|j|l]%mod;
				f[i|j|l]=f[i|j]+mod-t,f[i|j]+=t;
			}
		if(l==(1<<16))for(int i=0;i<n;i++)f[i]%=mod;
	}
	if(op)for(int i=0;i<n;i++)g[i]=f[i]%mod;
	else{
		ll iv=inv(n);
		for(int i=0;i<n;i++)g[i]=f[i]%mod*iv%mod;
	}
}

int n,m,lim=1,a[N],b[N],res[N];
int ans[N],cnt;
string A,B;
int main(){
	cin>>n>>m>>A>>B,reverse(A.begin(),A.end());
	while(lim<m)lim<<=1;
	
	for(int i=0;i<n;i++)if(A[i]!='*')a[i]=(A[i]-'a')*(A[i]-'a');
	for(int i=0;i<m;i++)if(B[i]!='*')b[i]=1;
	NTT(a,lim,1),NTT(b,lim,1);
	for(int i=0;i<lim;i++)res[i]=1ll*a[i]*b[i]%mod;
	
	mem(a,0,lim),mem(b,0,lim);
	for(int i=0;i<n;i++)if(A[i]!='*')a[i]=A[i]-'a';
	for(int i=0;i<m;i++)if(B[i]!='*')b[i]=B[i]-'a';
	NTT(a,lim,1),NTT(b,lim,1);
	for(int i=0;i<lim;i++)res[i]=(res[i]-2ll*a[i]*b[i]%mod+mod)%mod;
	
	mem(a,0,lim),mem(b,0,lim);
	for(int i=0;i<n;i++)if(A[i]!='*')a[i]=1;
	for(int i=0;i<m;i++)if(B[i]!='*')b[i]=(B[i]-'a')*(B[i]-'a');
	NTT(a,lim,1),NTT(b,lim,1);
	for(int i=0;i<lim;i++)res[i]=(res[i]+1ll*a[i]*b[i])%mod;
	
	NTT(res,lim,0);
	for(int i=n-1;i<m;i++)if(!res[i])ans[++cnt]=i-n+2;
	cout<<cnt<<endl;
	for(int i=1;i<=cnt;i++)cout<<ans[i]<<" ";
	
	return 0;
}

III. P4245 【模板】任意模数多项式乘法

const int N=(1<<18)+1;
const int B=32767;

struct cp{
	double re,im;
	cp operator + (cp x){return (cp){re+x.re,im+x.im};}
	cp operator - (cp x){return (cp){re-x.re,im-x.im};}
	cp operator * (cp x){return (cp){re*x.re-im*x.im,re*x.im+im*x.re};}
	cp operator * (double x){return (cp){re*x,im*x};}
	cp operator / (double x){return (cp){re/x,im/x};}
	cp conj (){return (cp){re,-im};}
}I,w[N],a0[N],b0[N],a1[N],b1[N],P[N],Q[N];

int lim=1,r[N];
void init(){
	for(int i=1;i<lim;i++)r[i]=(r[i>>1]>>1)|(i&1?lim>>1:0);
	for(int i=0;i<=lim;i++)w[i]=(cp){cos(2*i*Pi/lim),sin(2*i*Pi/lim)};
}

void FFT(cp *a,int op){
	static cp f[N];
	for(int i=0;i<lim;i++)f[i]=a[r[i]];
	for(int l=1;l<lim;l<<=1)
		for(int i=0,b=lim/l>>1;i<lim;i+=l<<1)
			for(int j=0,c=0;j<l;j++,c+=b){
				cp k=f[i|j|l]*w[~op?c:lim-c];
				f[i|j|l]=f[i|j]-k,f[i|j]=f[i|j]+k;
			}
	for(int i=0;i<lim;i++)a[i]=op==1?f[i]:f[i]/lim;
}
void FFFT(cp *a,cp *b){
	for(int i=0;i<lim;i++)a[i].im=b[i].re; FFT(a,1);
	for(int i=0;i<lim;i++)b[i]=a[i?lim-i:0].conj();
	for(int i=0;i<lim;i++){
		cp p=a[i],q=b[i];
		a[i]=(p+q)*0.5,b[i]=(q-p)*0.5*I;
	}
}

int n,m,p;
int main(){
	cin>>n>>m>>p,I={0,1};
	for(int i=0,v;i<=n;i++)v=read()%p,a0[i].re=v>>15,a1[i].re=v&B;
	for(int i=0,v;i<=m;i++)v=read()%p,b0[i].re=v>>15,b1[i].re=v&B;
	while(lim<=n+m)lim<<=1; init();
	FFFT(a0,a1),FFFT(b0,b1);
	for(int i=0;i<lim;i++){
		P[i]=a0[i]*b0[i]+a0[i]*b1[i]*I,
		Q[i]=a1[i]*b0[i]+a1[i]*b1[i]*I;
	}
	FFT(P,-1),FFT(Q,-1);
	for(int i=0;i<=n+m;i++){
		ll p1=(ll)(P[i].re+0.5)%p+p;
		ll p2=(ll)(P[i].im+Q[i].re+0.5)%p+p;
		ll p3=(ll)(Q[i].im+0.5)%p+p;
		printf("%lld ",((p1<<30)+(p2<<15)+p3)%p);
	}
	return 0;
}

IV. P3723 [AH2017/HNOI2017]礼物

设第 \(1\) 个手环与第 \(2\) 个手环增加亮度的差为 \(c\),那么有 \((x_i+c-y_i)^2=x_i^2+y_i^2+c^2+2x_ic-2y_ic-2x_iy_i\),注意到如果我们枚举 \(c\),那么就要求 \(\sum x_iy_i\) 的最小值。进行翻转后变成卷积形式,再破环成链即倍长数组,注意到值域 \(\leq 5\times 10^8\),使用 NTT 即可。

const ll G=3;
const ll ivG=inv(3);
const int N=1<<18;

int n,m,len,r[N];
void init(int n){for(int i=1;i<n;i++)r[i]=(r[i>>1]>>1)|(i&1?n>>1:0);}
void NTT(ll *a,int n,bool op){
	static ull f[N],w[N]; w[0]=1;
	for(int i=0;i<n;i++)f[i]=a[r[i]];
	for(int l=1;l<n;l<<=1){
		ll wn=ksm(op?G:ivG,(mod-1)/(l+l));
		for(int i=1;i<l;i++)w[i]=w[i-1]*wn%mod;
		for(int i=0;i<n;i+=l+l)
			for(int j=0;j<l;j++){
				int y=f[i|j|l]*w[j]%mod;
				f[i|j|l]=f[i|j]+mod-y,f[i|j]+=y;
			} if(l==(1<<16))for(int i=0;i<n;i++)f[i]%=mod;
	} if(!op){
		ll iv=inv(n);
		for(int i=0;i<n;i++)a[i]=f[i]%mod*iv%mod;
	} else for(int i=0;i<n;i++)a[i]=f[i]%mod;
}

ll a[N],b[N];
ll ssq,ss,res,ans=1e10;
int main(){
	cin>>n>>m;
	for(int i=0;i<n;i++)scanf("%d",&a[i]),ssq+=a[i]*a[i],ss+=a[i];
	for(int i=0;i<n;i++)scanf("%d",&b[i]),ssq+=b[i]*b[i],ss-=b[i],b[i+n]=b[i];
	reverse(a,a+n);
	len=1; while(len<3*n)len<<=1;
	init(len),NTT(a,len,1),NTT(b,len,1);
	for(int i=0;i<len;i++)a[i]=a[i]*b[i]%mod;
	NTT(a,len,0);
	for(int i=n-1;i<2*n;i++)res=max(res,a[i]);
	for(int i=-m;i<=m;i++)ans=min(ans,ssq+2*ss*i+n*i*i-2*res);
	cout<<ans<<endl;
	return 0;
}
posted @ 2021-08-01 15:35  qAlex_Weiq  阅读(2039)  评论(1编辑  收藏  举报