题解[loj #6682. 梦中的数论]

原题链接

题意:求 \(\sum\limits_{i=1}^{n}\sum\limits_{j=1}^{n}\sum\limits_{k=1}^{n}[j|i][(j+k)|i],n\leq 10^{10}\)

\(f(i)=\sum\limits_{j=1}^{i}\sum\limits_{k=1}^{i}[j|i][(j+k)|i]\) ,则原式为 \(f(i)\) 的前缀和。

\([j|i][(j+k)|i]\) 相当于从 \(i\) 的因数中选出两个的方案数,为 \(\dbinom{\sigma_0(i)}{2}\)

所以 \(f(i)=\dbinom{\sigma_0(i)}{2}=\dfrac{\sigma_o(i)^2-\sigma_0(i)}{2}\)

其中 \(\sigma_0(i)\) 的前缀和是易得的:\(\sum\limits_{i=1}^{n}\sum\limits_{d|i}=\sum\limits_{d=1}^{n}\left\lfloor\dfrac{n}{d}\right\rfloor\) 一次数论分块即可。

重点看 \(\sigma_0(i)^2\) 的前缀和,这个积性函数不得不重推 \(\text{dgf}\) 来探究。

为了下文表述方便,设 \(f_p(x)=\sum\limits_{k\geq 0}\sigma(p^k)^2x^k\) ,即设对每个质数 \(p\)\(x=p^{-z}\) 转为贝尔级数的形式。

\(\sigma_0(i)^2\)\(\text{dgf}\) 就是 \(\prod\limits_{p\in \text{Prime}}f_p(p^{-z})\)

由于 \(\sigma_0(p^k)^2=(k+1)^2\)

所以 \(f_p(x)=\sum\limits_{k\geq 0}(k+1)^2x^k=\sum\limits_{k\geq 0}k^2x^k+\sum\limits_{k\geq 0}kx^k+\dfrac{1}{1-x}\)

由于 \(\sum\limits_{k\geq 0}a_kx^k=F(x)\Rightarrow\sum\limits_{k\geq 0}t(k)a_kx^k=t(x\mathrm{D})F(x)\) ,其中 \(x\mathrm D\) 代表对原式求导后乘 \(x\) 的算子,\(t\) 是一个多项式。

这用第二类 \(\text{Stirling Numbers}\)\(t(k)\) 每一项转为下降幂,再用具体数学的习题 \((6.13)\) 容易证明。

所以

\[f_p(x)=(x\mathrm{D})^2\dfrac{1}{1-x}+2(x\mathrm D)\dfrac{1}{1-x}+\dfrac{1}{1-x} \]

\[=(x\mathrm{D})\dfrac{x}{(1-x)^2}+\dfrac{2x}{(1-x)^2}+\dfrac{1}{1-x} \]

\[=x\dfrac{(x^2-2x+1)-(2x-2)x}{(1-x)^4}+\dfrac{2x}{(1-x)^2}+\dfrac{1}{1-x} \]

\[=x\dfrac{1-x+2x}{(1-x)^3}+\dfrac{2x}{(1-x)^2}+\dfrac{1}{1-x} \]

\[=\dfrac{x+x^2+2x-2x^2+x^2-2x+1}{(1-x)^3}=\dfrac{1+x}{(1-x)^3} \]

这式子十分简洁,于是 \(\sigma_0(i)^2\)\(\text{dgf}\) 就是:

\[\prod\limits_{p\in \text{Prime}}\dfrac{1+p^{-z}}{(1-p^{-z})^3}=\prod\limits_{p\in \text{Prime}}\dfrac{1-p^{-2z}}{(1-p^{-z})^4}=\dfrac{\zeta(z)^4}{\zeta(2z)} \]

这可以看做 \(\zeta(z)^4\times \dfrac{1}{\zeta(2z)}\) ,而 \(\zeta(2z)=\prod\limits_{p\in \text{Prime}}1-p^{-2z}\) 仅在分解后各个质数次幂为 \(2\) 的数中有值。

所以 \(\dfrac{1}{\zeta(2z)}\) 仅在 \(\text{PN}\) 中的一部分中有值,个数小于 \(\sqrt n\) ,能直接搜出来。

剩下的就是求 \(\zeta(z)^4\) 的前缀和,可以杜教筛。

这相当于 \(\zeta(z)^2\) 与自身的卷积,而为了求出 \(\zeta(z)^2\) 需要 \(\zeta(z)\) 与自身的卷积。

\(\zeta(z)^2\)\(\zeta(z)\) 需要的仅是 \(\left\lfloor\dfrac{n}{k}\right\rfloor\) 处的 \(O(\sqrt n)\) 个值,在预处理出前 \(n^{\frac{2}{3}}\) 的前缀和后用杜教筛的时间复杂度是 \(O(n^{\frac{2}{3}})\)

由于我比较懒,在处理前 \(n^{\frac{2}{3}}\)\(\zeta(z)^2,\zeta(z)^4\) 的前缀和时直接用了 \(\text{P5495}\) 的方法多一个 \(\ln\ln n\) 算。

那这样预处理的范围可以适当调整一下。

总时间复杂度是 \(O(n^{\frac{2}{3}}\ln\ln n)\)\(O(n^{\frac{2}{3}})\) ,取决于是否线性筛。

但即使是 \(O(n^{\frac{2}{3}}\ln\ln n)\) 的方法在 \(\text{loj}\) 上也能跑的十分快,不可思议地快过了一部分 \(\text{Min25}\) 筛。

代码:

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int NN=4641588;
const ll mod=998244353;
ll n,x;bool b[NN+10];char ch;
int N,nn,p[NN+10],prm;ll ans;
ll sig_1[NN+10],sig_2[NN+10],sig_3[NN+10],sig_4[NN+10];
inline ll add(ll x,ll y){return x+y<mod?x+y:x+y-mod;}
void write(int x){if(x>9)write(x/10);putchar(48+x%10);}
inline void summing(ll *f,ll *g){
	register int i,j,pj,k;
	for(i=1;i<=N;++i)g[i]=f[i];
	for(j=1;j<=prm;++j){
		pj=p[j];
		for(i=1;(k=i*pj)<=N;++i)g[k]=add(g[k],g[i]);
	}
}
void pre_work(){
	register int i,j,k;
	for(i=2;i<=N;++i){
		if(!b[i])p[++prm]=i;
		for(j=1;j<=prm&&(k=i*p[j])<=N;++j){
			b[k]=1;if(i%p[j]==0)break;
		}
	}
	for(i=1;i<=N;++i)sig_1[i]=1;
	summing(sig_1,sig_2);
	summing(sig_2,sig_3);
	summing(sig_3,sig_4);
	for(i=1;i<=N;++i)sig_2[i]=add(sig_2[i],sig_2[i-1]);
	for(i=1;i<=N;++i)sig_4[i]=add(sig_4[i],sig_4[i-1]);
}
ll H2[NN];
ll getsum_2(ll n,ll k){
	if(n<=N)return sig_2[n];
	if(H2[k])return H2[k];
	register ll l,r,res=0;
	for(l=1;l<=n;l=r+1){
		r=n/(n/l);
		res=add(res,(n/l)*(r-l+1)%mod);
	}
	return H2[k]=res;
}
ll getsum_4(ll n,ll k){
	if(n<=N)return sig_4[n];
	register ll l,r,t,res=getsum_2(n,k);
	for(l=2;l<=n;l=r+1){
		t=n/l;r=n/t;
		res=add(res,(getsum_2(r,k*t)-getsum_2(l-1,k*(n/(l-1))))*getsum_2(t,k*l)%mod);
	}
	return res;
}
void getsum(int _n,int id,int v){
	ll _n_=1ll*_n*_n;
	ans=(ans+getsum_4(n/_n_,_n_)*v)%mod;
	register int i,j;
	for(i=id;i<=prm&&_n<=nn/p[i];++i)getsum(_n*p[i],i+1,-v);
}
main(){
	ch=getchar();while(ch>47)n=(n<<3)+(n<<1)+(ch^48),ch=getchar();
	N=pow(n,0.66666);if(N>2e5)N/=2;pre_work();
	nn=sqrt(n);register int i;
	for(i=1;i<=prm;++i)if(p[i]>nn)break;prm=i-1;
	getsum(1,1,1);ans=(ans-getsum_2(n,1))%mod;
	ans=(ans+mod)%mod*(mod+1)/2%mod;
	write(ans);
}
posted @ 2021-10-17 12:38  Y_B_X  阅读(54)  评论(0编辑  收藏  举报