题解 P5282 【模板】快速阶乘算法
总算是学会了这个算法......
【前置芝士】
前两个用于处理任意模数意义下的多项式乘法;
第三个用于在未知一个不超过 \((r-l)\) 次的多项式具体形式,但已知其在某连续区间 \([l,r]\) 的 \((r-l+1)\) 个点值时,求出任一与之不交的、长度相同的区间的 \((r-l+1)\) 个点的值。复杂度为 \(O(n\log n)\)。
【分析】
考虑令 \(s=\lfloor\sqrt n\rfloor\) ,则 \(\displaystyle n!=\prod_{i=1}^n i=\prod_{i=0}^{s-1}\prod_{t=1}^s(is+t)\cdot \prod_{i=s^2+1}^n i\) 。
若已求出 \(\displaystyle \prod_{i=0}^{s-1}\prod_{t=1}^s(is+t)\),则后面的那个式子直接暴力点乘,复杂度为 \(o(\sqrt n)\) 的。
有个比较简单的方法是由分治 FFT 或多项式启发式合并求出 \(s\) 次多项式 \(\displaystyle g_s(x)=\prod_{t=1}^s (x+t)\) ;再由多项式多点求值求出 \(g_s(0\cdot s),g_s(1\cdot s),g_s(2\cdot s),\cdots, g_s((s-1)\cdot s)\) 。这些乘起来即为待求项。
这样一来,复杂度为 \(O(s\log^2 s)+O(s\log^2 s)=O(\sqrt n\log^2 n)\) 。
但考虑到最后实际对答案的贡献仅在若干个点处取到。若这些点所处的区间连续,则可能可用多项式连续点值平移求出答案。
很幸运的是,我们令 \(h_s(x)=g_s(sx)\),则待求的点值变换为了 \(h_s(0),h_s(1),h_s(2),\cdots, h_s(s-1)\) ,是长度为 \(s\) 的连续区间。这就可以用到多项式连续点值平移了。
现在,我们可以考虑如何使用多项式连续点值平移凑出这些东西了。
而初始我们也并不知道 \(h_s(x)\) 的任何一个点值,但是我们知道 \(h_1(x)=x+1\) 的所有点值。
因此,我们考虑如何倍增得到 \(h_s(x)\) 的那些点值。
当我们已知 \(h_d(0),h_d(1),\cdots,h_d(d)\) 时,若我们想要得到 \(h_{2d}(0),h_{2d}(1),\cdots,h_{2d}(2d)\)。
对于 \(\displaystyle \forall i\leq d\to h_{2d}(i)=g_{2d}(is)=\prod_{t=1}^{2d}(is+t)=g_d(is)\cdot g_d(is+d)\) 。
由于 \(g_d(is+d)\) 在模 \(P\) 意义下,与 \(h_d(d\cdot s^{-1}+i)\) 等价,因此 \(h_{2d}(i)=h_d(i)\cdot h_d(i+d\cdot s^{-1})\) 。
求解这些点值时,我们只要将 \(h_d(0),h_d(1),\cdots,h_d(d)\) 这些连续点值平移出 \(h_{2d}(0+d\cdot s^{-1}),h_d(1+d\cdot s^{-1}),\cdots, h_d(d+d\cdot s^{-1})\) ,直接两两点乘即可倍增。
而对于后面的那几个点,如果我们已知 \(h_d(d+1),h_d(d+2),\cdots, h_d(2d)\) ,就也能通过这个方法求解。
那怎么知道呢?问就是多项式连续点值平移,用 \(h_d(0),h_d(1),\cdots,h_d(d)\) 直接平移出 \(h_d(d+1),h_d(d+2),\cdots,h_d(2d+1)\) 就可以了(这里因为要求区间不交,会额外平移出 \(h_d(2d+1)\))。
既然这样,那也别客气了,分那么清楚, 第一次先平移出 \(h_d(d+1),h_d(d+2),\cdots, h_d(2d+1)\) ;拼接到原点值后面之后,第二次再直接用 \(h_d(0),h_d(1),\cdots,h_d(2d+1)\) 直接平移出 \(h_d(0+d\cdot s^{-1}),h_d(1+d\cdot s^{-1}),\cdots,h_d(2d+1+d\cdot s^{-1})\) 。点乘则可以得到我们要求的 \(h_{2d}(0),h_{2d}(1),\cdots,h_{2d}(2d)\) 。复杂度为 \(O(d\log d)+O(2d\log 2d)+O(2d)=O(d\log d)\) 。
由于我们可以从 \(h_d\) 的状态推出 \(h_{2d}\) 的状态,我们又已知 \(h_1\) 的状态,待求 \(h_s\) 的状态;于是我们去考虑用类似递归快速幂的迭代方法求出结果。
那还差的步骤就是由 \(h_d\) 的状态推出 \(h_{d+1}\) 的状态。
这个简单, 由于 \(\displaystyle h_{d+1}(i)=g_{d+1}(is)=\prod_{t=1}^{d+1}(is+t)=g_d(is)\cdot (is+d+1)=h_d(i)\cdot (is+d+1)\) ,我们直接 \(O(d)\) 暴力跑过去就行了。
但是由 \(h_d\) 推 \(h_{d+1}\) 时,需要额外知道 \(h_{d+1}(d+1)\) 的信息,它需要用到 \(h_d(d+1)\) 的信息。
但这个无伤大雅,因为 \(h_d\) 推出 \(h_{d+1}\) 的状态一定发生在 \(h_{d\over 2}\) 推出 \(h_d\) 之后。而像上面说的,正好由于区间不交,恰好在 \(h_d\) 推出 \(h_{2d}\) 时,把 \(h_{2d}(2d+1)\) 的给顺便算了。
那就好办了,如果接着要 \(h_{2d}\) 推 \(h_{2d+1}\) ,就把这个正好算了;如果不要,就顺手丢了。于是这一步的复杂度是严格 \(O(d)\) 的。
综上,我们由已知的 \(h_1\) 状态,倍增地推出 \(h_s\) 的状态。最后累乘 \(h_s(0),\cdots, h_s(s-1)\) ,再乘上那些尾巴,就可以算出总答案了。
复杂度 \(\displaystyle T(s)=T({s\over 2})+O(s\log s)+O(s)\) 得到 \(T(s)=O(s\log s)=O(\sqrt n\log n)\) 。(原本还要加上一个 \(o(\sqrt n)\) ,但是作为低阶项舍去了。)
【核心代码】
poly c, d;
inline int get_fac(int n) {
int pos=curStk, s=sqrt(n)+1e-6;
vir invs=Inv[s];
for(int i=s;i>1;i>>=1) Stk[++curStk]=i;//手动压栈
c.resize(2); c[0]=1; c[1]=s+1;//初始化 h_1 的状态
for(int l=Stk[curStk]; curStk>pos; l=Stk[--curStk]) {
ValueTrans(c, d, l>>1, (l>>1)+1);//点值平移出额外状态
c.resize(2*sz(c));
for(int i=0; i<sz(d); ++i) c[sz(d)+i]=d[i];
ValueTrans(c, d, sz(c)-1, invs*vir(l>>1));//点值平移出点乘状态
for(int i=0; i<sz(c); ++i) c[i]=c[i]*d[i];
if(l&1) {
for(int i=0; i<=l; ++i) c[i]=c[i]*vir(i*s+l);//h_l 转为 h_{l+1}
}
else c.resize(l+1);
}
vir res=1;
for(int i=0; i<s; ++i) res=res*c[i];//累乘
for(int i=s*s+1; i<=n; ++i) res=res*vir(i);//处理尾巴
return res;
}
【拓展】
你以为这就完了?
我发现我的 代码 跑得飞快,看到连时限 4s 的十分之一都没跑到,我不禁“恶向胆边生”。用这个板子改了改,测试了一下阶乘求和的结果。
做法是求出 \(\displaystyle \sum_{i=1}^n i!\) 和 \(\displaystyle \sum_{i=1}^{n-1} i!\),做差算出答案。
那这个问题怎么用点值平移处理呢?
我们记 \(\displaystyle S_n=\sum_{i=1}^n i!\) ,考虑它的转移矩阵:
\(\begin{aligned} \begin{pmatrix} (n+1)! \\S_n \end{pmatrix}&= \begin{pmatrix} n+1&0 \\1&1 \end{pmatrix}\cdot \begin{pmatrix} n! \\S_{n-1} \end{pmatrix} \\&=\prod_{i=1}^n \begin{pmatrix} i+1&0 \\1&1 \end{pmatrix}\cdot \begin{pmatrix} 1 \\0 \end{pmatrix} \end{aligned}\)
同理,我们考虑 \(\displaystyle g_d(x)=\prod_{i=1}^d\begin{pmatrix} x+i+1&0 \\1&1 \end{pmatrix},h_d(i)=g_d(is)\) 。
手玩一下会发现,这个矩阵这个矩阵的乘法也有特点:
\(\begin{pmatrix} a_1&\ \\b_1&c_1 \end{pmatrix}\cdot\begin{pmatrix} a_2&\ \\b_2&c_2 \end{pmatrix}=\begin{pmatrix} a_1a_2&\ \\b_1a_2+c_1b_2&c_1c_2 \end{pmatrix}\)
也就是说,这个矩阵的乘积永远只有三个地方有值,并且右下角恒为 \(1\) 。于是我们维护 \(h_d(x)=\begin{pmatrix}mc_{d,0}(x)&\ \\mc_{d,1}(x)&1\end{pmatrix}\) 。
由原来维护一个点值变成了维护两个点值 而已 。
由于此时为矩阵乘法 \(h_{2d}(i)=h_d(i+d\cdot s^{-1})\cdot h_d(i)\) 形式不会发生变化,但一定要注意矩阵的左右乘不等价。
发生变化的是 \(h_d\) 转移到 \(h_{d+1}\) 时,此时有:
\(\begin{aligned} &h_{d+1}(i) \\=&\begin{pmatrix}mc_{d+1,0}(i)&\ \\mc_{d+1,1}(i)&1\end{pmatrix} \\=&\begin{pmatrix}i+d+2&0\\1&1\end{pmatrix}h_d(i) \\=&\begin{pmatrix}i+d+2&0\\1&1\end{pmatrix}\begin{pmatrix}mc_{d,0}(x)&\ \\mc_{d,1}(x)&1\end{pmatrix} \\=&\begin{pmatrix}(i+d+2)mc_{d,0}(x)&\ \\mc_{d,0}+mc_{d,1}(x)&1\end{pmatrix} \end{aligned}\)
修改此处即可。
poly mc[2], md[2];
inline int get_sumfac(int n) {
int pos=curStk, s=sqrt(n)+1e-6;
vir invs=Inv[s];
for(int i=s;i>1;i>>=1) Stk[++curStk]=i;
mc[0].resize(2); mc[0][0]=2; mc[0][1]=s+2;
mc[1].resize(2); mc[1][0]=mc[1][1]=1;
for(int l=Stk[curStk]; curStk>pos; l=Stk[--curStk]) {
for(int t=0; t<2; ++t){
poly &c = mc[t], &d = md[t];
ValueTrans(c, d, l>>1, (l>>1)+1);
c.resize(2*sz(c));
for(int i=0; i<sz(d); ++i) c[sz(d)+i]=d[i];
ValueTrans(c, d, sz(c)-1, invs*vir(l>>1));
}
for(int i=0; i<sz(mc[0]); ++i) {
mc[1][i]=mc[0][i]*md[1][i]+mc[1][i];
mc[0][i]=mc[0][i]*md[0][i];
}
if(l&1) {
for(int i=0; i<=l; ++i) {
mc[1][i]=mc[0][i]+mc[1][i];
mc[0][i]=mc[0][i]*vir(i*s+l+1);
}
}
else {
mc[0].resize(l+1);
mc[1].resize(l+1);
}
}
vir res0=1, res1=0;
for(int i=0; i<s; ++i){
res1=res0*mc[1][i]+res1;
res0=res0*mc[0][i];
}
for(int i=s*s+1; i<=n; ++i){
res1=res1+res0;
res0=res0*vir(i+1);
}
return res1;
}
结果 还是没跑到一半的时间。