Loading

关于积性函数求和的一点想法

好像会了一个 \(O(n^{0.5+o(1)})\) 的积性函数求和方法(得到所有块筛)。不过 OI 中不会有用就是了。

UPD : 是 \(\sqrt n \operatorname{polylog}(n)\) 的。

Part1

现在假设我们能解决如下问题:

问题:对于两个序列 \(a,b\),我们已知其在 \(n/k\) 上的前缀和。
\(f(z) = \sum_{xy = z} a(x) b(y)\)\(n/k\) 上的前缀和。

在下面,我们称这是两个序列的“卷积”,用 \(\times\) 表示,\(k\)\(a\) 卷起来是 \(a^k\)

首先如果 \(f(p)\) 是多项式,是可以直接用“冷群筛”\(k \log V\) 次卷积解决的(\(V\) 是值域。如果有负数那可能要取模一下再做)。

但是其实有一种比较通用的方法。(还没有写过,不过感觉很对。)

类似 Min_25 筛的想法:

\(f(p)\) 的值拆成若干个前缀和好算的积性函数(是这个算法使用的条件,同时也是 Min_25 筛使用的条件)。

考虑先求出那些好算积性函数在质数处的的前缀和。这样我们就得到了 \(f(p)\) 在质数处的前缀和。

先考虑如果得到了 \(f(p)\) 在质数处的前缀和(设为 \(g\)\(g\) 只有质数处有值),接下来该怎么做。

考虑计算 \(g\) 的 “\(\exp\)\(h=\sum\limits_{i=0} \frac{g^{i}}{i!}\)

注意到这个是一个积性函数,且 \(h(p^k) = \frac{f(p)^k}{k!}\)

因此有 \(h(p)=f(p)\),从 \(h\) 得到 \(f\) 只要再卷一次 \(\textrm{Powerful Number}\) 即可。


而从好算积性函数得到其质数处的值,就只要做上面的逆过程:

  • \(\textrm{Powerful Number}\) 得到 \(h(p^k) = \frac{f(p)^k}{k!}\)
  • \(h\) 取 “\(\ln\)”:\(\sum_{i=1} (-1)^{i-1} \frac{(h-e)^i}{i}\)。其中 \(e(x)=[x=1]\)

每次 “\(\ln,\exp\)” 都只需要枚举到 \(\log n\) 级别。

这样,如果我们把 \(f\) 拆成 \(k\) 个好算的积性函数,那么需要 \(k \log n\) 次卷积。

Part2

现在我们要解决上面提到的问题。

如果 \(x > \sqrt n\)\(y > \sqrt n\)。这里不妨 \(x>\sqrt n\)

\(t=n/x\)\(\lfloor n/x/y \rfloor = \lfloor \frac{t}{y} \rfloor\)

因此枚举 \(y\),对 \(t\) 做前缀和优化就是 \(\Theta(\sqrt n \log n)\) 的。


现在是 \(x,y \le \sqrt n\),怎么办呢!!

首先 \(xy \le \sqrt n\) 暴力。

然后就是要在 \(\lfloor n/(xy) \rfloor\) 上记录 \(a_x \times b_y\) 的值。

首先这个 \(\lfloor n/(xy) \rfloor\) 事实上是最大的 \(p\) 满足 \(xy \le n/p\)

\(\ln\)\(\ln(x)+\ln(y) \le \ln(n/p)\)

考虑对 \(\ln\) 做一个估计:给左边乘以 \(B\),然后向上取整。

也就是说我们找最后一个 \(\lceil B\ln(x) \rceil + \lceil B\ln(y)\rceil \le B\ln(n/p)\)\(p\)

这样 \(\ln\) 的值域是 \(B \log n\),然后对 \(\lceil B\ln(x) \rceil,\lceil B\ln(y)\rceil\) 做 FFT 的复杂度是 \(B \log^2 n\)

但是这样做显然是不正确的。当 \(|\ln(x) + \ln(y) - \ln(n/p)| \le \frac{2}{B}\) 的时候可能会出问题。

对于这样的 \((x,y,p)\),有:

\[|\ln(x) + \ln(y) + \ln(p) - \ln(n)| \le \frac{2}{B} \]

\[|\ln(xyp) - \ln(n)| \le \frac{2}{B} \]

注意到 \(\ln(x)' = \frac{1}{x}\),因此 \(|xyp-n|\)\(\frac{\frac{2}{B}}{\frac{1}{n}}\) 级别也就是 \(\Theta(\frac{n}{B})\) 级别的!

那么对所以暴力枚举 \(s=xyp\),不妨有 \(s \in [l,r]\) 内。

区间筛出 \([l,r]\) 的质因子,然后暴力枚举 \(x,y\) 即可。

\(B=\sqrt n\) 时算法有 一个比较显然的上界是 \(\Theta(\sqrt n\log n\max(d))\),可以看作 \(O(n^{0.5+o(1)})\)

UPD:

EI 2023/7/12 9:07:49
我们搜了一下, 这里把 sum_{xyz<=n} 1 的误差已经控制到了 O(n^0.45), 所以对长为 sqrt(n) 的区间来说已经够用了 https://en.wikipedia.org/wiki/Divisor_summatory_function#Piltz_divisor_problem

这里提到了 \(\sum_{i \le n} d_3(i) = n P_3(\ln n) + \mathcal O(n^{43 / 96 + \varepsilon})\),其中 \(P_3\) 是二次多项式。

对长度为根号的区间估计可以忽略掉那个 \(\mathcal O(n^{43 / 96 + \varepsilon})\),时间复杂度即为 \(\Theta(\sqrt n \log^2 n)\)

上面提到的积性函数要做 \(\log n\) 次卷积,因此总复杂度为 \(\Theta(\sqrt n \log^3 n)\)


一份代码实现(只实现了 Part2):

#include<bits/stdc++.h>
#define L(i, j, k) for(int i = (j); i <= (k); ++i)
#define R(i, j, k) for(int i = (j); i >= (k); --i)
#define ll long long 
#define vi vector < int > 
#define sz(a) ((int) (a).size())
#define ll long long 
#define ull unsigned long long
#define me(a, x) memset(a, x, sizeof(a)) 
using namespace std;
const int mod = 998244353, _G = 3, N = 1 << 21, inv2 = (mod + 1) / 2;
#define add(a, b) (a + b >= mod ? a + b - mod : a + b)
#define dec(a, b) (a < b ? a - b + mod : a - b)
int qpow(int x, int y = mod - 2) {
	int res = 1;
	for(; y; x = (ll) x * x % mod, y >>= 1) if(y & 1) res = (ll) res * x % mod;
	return res;
}
int rt[N * 2 + 3], Lim;
void Pinit(int x) {
	for(Lim = 1; Lim <= x; Lim <<= 1) ;
	for(int i = 1; i < Lim; i <<= 1) {
		int sG = qpow (_G, (mod - 1) / (i << 1));
		rt[i] = 1;
		L(j, i + 1, i * 2 - 1) rt[j] = (ll) rt[j - 1] * sG % mod;
	}
}
struct poly {
	vector<int> a;
	int size() { return sz(a); }
	int & operator [] (int x) { return a[x]; }
	int v(int x) { return x < 0 || x >= sz(a) ? 0 : a[x]; }
	void clear() { vector<int> ().swap(a); }
	void rs(int x = 0) { a.resize(x); }
	poly (int n = 0) { rs(n); }
	poly (vector<int> o) { a = o; }
	poly (const poly &o) { a = o.a; }
	poly Rs(int x = 0) { vi res = a; res.resize(x); return res; }
	inline void dif() {
		int n = sz(a);
		for (int l = n >> 1; l >= 1; l >>= 1) 
			for(int j = 0; j < n; j += l << 1) 
				for(int k = 0, *w = rt + l; k < l; k++, w++) {
					int x = a[j + k], y = a[j + k + l];
					a[j + k] = add(x, y);
					a[j + k + l] = (ll) * w * dec(x, y) % mod;
				}
	}
	void dit () {
		int n = sz(a);
		for(int i = 2; i <= n; i <<= 1) 
			for(int j = 0, l = (i >> 1); j < n; j += i) 
				for(int k = 0, *w = rt + l; k < l; k++, w++) {
					int pa = a[j + k], pb = (ll) a[j + k + l] * *w % mod;
					a[j + k] = add(pa, pb), a[j + k + l] = dec(pa, pb);
				}
		reverse(a.begin() + 1, a.end());
		for(int i = 0, iv = qpow(n); i < n; i++) a[i] = (ll) a[i] * iv % mod;
	} 
	friend poly operator * (poly aa, poly bb) {
		if(!sz(aa) || !sz(bb)) return {};
		int lim, all = sz(aa) + sz(bb) - 1;
		for(lim = 1; lim < all; lim <<= 1);
		aa.rs(lim), bb.rs(lim), aa.dif(), bb.dif();
		L(i, 0, lim - 1) aa[i] = (ll) aa[i] * bb[i] % mod;
		aa.dit(), aa.a.resize(all);
		return aa;
	}
	poly Inv() {
		poly res, f, g;
		res.rs(1), res[0] = qpow(a[0]);
		for(int m = 1, pn; m < sz(a); m <<= 1) {
			pn = m << 1, f = res, g.rs(pn), f.rs(pn);
			for(int i = 0; i < pn; i++) g[i] = (*this).v(i);
			f.dif(), g.dif();
			for(int i = 0; i < pn; i++) g[i] = (ll) f[i] * g[i] % mod;
			g.dit();
			for(int i = 0; i < m; i++) g[i] = 0;
			g.dif();
			for(int i = 0; i < pn; i++) g[i] = (ll) f[i] * g[i] % mod;
			g.dit(), res.rs(pn);
			for(int i = m; i < min(pn, sz(a)); i++) res[i] = (mod - g[i]) % mod;
		} 
		return res.rs(sz(a)), res;
	}
	poly Shift (int x) {
		poly zm (sz(a) + x);
		L(i, max(-x, 0), sz(a) - 1) zm[i + x] = a[i];
		return zm; 
	}
	friend poly operator * (poly aa, int bb) {
		poly res(sz(aa));
		L(i, 0, sz(aa) - 1) res[i] = (ll) aa[i] * bb % mod;
		return res;
	}
	friend poly operator + (poly aa, poly bb) {
		vector<int> res(max(sz(aa), sz(bb)));
		L(i, 0, sz(res) - 1) res[i] = add(aa.v(i), bb.v(i));
		return poly(res);
	}
	friend poly operator - (poly aa, poly bb) {
		vector<int> res(max(sz(aa), sz(bb)));
		L(i, 0, sz(res) - 1) res[i] = dec(aa.v(i), bb.v(i));
		return poly(res);
	}
	poly & operator += (poly o) {
		rs(max(sz(a), sz(o)));
		L(i, 0, sz(a) - 1) (a[i] += o.v(i)) %= mod;
		return (*this);
	}
	poly & operator -= (poly o) {
		rs(max(sz(a), sz(o)));
		L(i, 0, sz(a) - 1) (a[i] += mod - o.v(i)) %= mod;
		return (*this);
	}
	poly & operator *= (poly o) {
		return (*this) = (*this) * o;
	}
} ;

ll n, sq;
bool Prime[N];
int p[N], ptot;
mt19937_64 orz;
void xxs(int x) {
	L(i, 2, x) {
		if(!Prime[i]) p[++ptot] = i;
		for(int j = 1; p[j] * i <= x && j <= ptot; ++j) {
			Prime[p[j] * i] = true;
			if(i % p[j] == 0) break;
		} 
	}
}
ll w[N], id1[N], id2[N], tp;
inline int getid(ll x) {
	return x <= sq ? id1[x] : id2[n / x];
}
double iv[N];

double lg[N];
vi Mul(vi A, vi B) {
	L(i, 1, sz(A) - 2) (A[i] += mod - A[i + 1]) %= mod;
	L(i, 1, sz(B) - 2) (B[i] += mod - B[i + 1]) %= mod;
	vi ans(tp + 1);
	vi av(sq + 1), bv(sq + 1);
	
	int bp = id1[sq] - 1;
	vi as(sq + 1), bs(sq + 1), cs(sq + 3);
	vi ma(sq + 3), mb(sq + 3), mc(sq + 3);
	L(i, 1, sq) 
		ma[i] = A[getid(i)], mb[i] = B[getid(i)];
	L(i, 1, bp) 
		as[i] = A[getid(n / i)], bs[i] = B[getid(n / i)];
	
	L(i, 1, sq) {
		L(j, 1, sq / i) {
			(mc[i * j] += (ll) ma[i] * mb[j] % mod) %= mod;
		}
	}
	
	vi asuf(bp + 2), bsuf(bp + 2);
	L(i, 1, bp) {
		asuf[i] = as[i];
		bsuf[i] = bs[i];
	}
	R(i, bp, 1) (asuf[i] += asuf[i + 1]) %= mod;
	R(i, bp, 1) (bsuf[i] += bsuf[i + 1]) %= mod;
	
	L(i, 1, sq) {
		L(j, 1, bp / i) {
			(cs[j] += (ll) mb[i] * (asuf[i * j] + mod - asuf[min(bp + 1, i * (j + 1))]) % mod) %= mod;
		}
		L(j, 1, bp / i) {
			(cs[j] += (ll) ma[i] * (bsuf[i * j] + mod - bsuf[min(bp + 1, i * (j + 1))]) % mod) %= mod;
		}
//		L(j, i, bp) 
//			(cs[j / i] += (ll) as[j] * mb[i] % mod) %= mod;
//		L(j, i, bp) 
//			(cs[j / i] += (ll) bs[j] * ma[i] % mod) %= mod;
	}
	
	vector < double > LG(tp + 3), slg(tp + 3);
	L(i, 1, sq) {
		LG[i] = log(i);
	}
	double lgn = log(n);
	L(i, 1, bp) {
		slg[i] = lgn - LG[i];
	}
	ll kanz = sq * 0.8;
	kanz = max(kanz, 1LL);
	
	vi BL(tp + 1);
	L(i, 1, sq) {
		BL[i] = ceil(LG[i] * kanz + 1e-10);
	}
	L(i, 1, bp) {
		slg[i] *= kanz;
	}
	
	ll WA = 0;
	int BOUND = BL[sq] * 2;
	vi pointer(BOUND + 1);
	int P = 1;
	R(i, BOUND, 0) {
		while(P <= bp && i < slg[P]) ++P; 
		pointer[i] = P - 1;
	}
	
	vi spin(BOUND + 1);
	cerr << "BOUND = " << 1. * BOUND << endl;
	
	poly MA(BL[sq] + 1), MB(BL[sq] + 1);
	L(i, 1, sq) 
		(MA[BL[i]] += ma[i]) %= mod, 
		(MB[BL[i]] += mb[i]) %= mod;
	MA *= MB;
	L(i, 1, BOUND) 
		spin[i] = MA[i];
	
	L(i, 1, sq) 
		L(j, 1, sq / i) 
			(spin[BL[i] + BL[j]] += mod - (ll) ma[i] * mb[j] % mod) %= mod;
	
	ll TL = n, TR = n;
	for(; TL && (log(n) - log(TL)) * kanz <= 2; --TL) ;
	TL -= 5, TL = max(TL, 1LL);
	
	ll lens = TR - TL + 1;
	vector < ll > val(lens + 1);
	vector < vector < pair < ll, int > > > FAC(lens + 1);
	L(i, 1, lens) val[i] = i + TL - 1;
	L(u, 1, ptot) {
		ll p = ::p[u];
		if(p > sq) continue; 
		ll ii = (p - TL % p) % p + 1;
		for(ll j = ii; j <= lens; j += p) {
			int nt = 0;
			while(val[j] % p == 0) {
				val[j] /= p, ++nt;
			}
			FAC[j].emplace_back(p, nt); 
		}
	}
	
	vector < pair < ll, int > > curs; 
	
	ll MULS = 0;
	auto check = [&] (int i, int j) {
		int p = pointer[BL[i] + BL[j]];
		ll ML = (ll) i * j * (p + 1);
		if((ll)i * j > sq && ML == MULS) {
			(spin[BL[i] + BL[j]] += mod - (ll) ma[i] * mb[j] % mod) %= mod;
			while((ll)i * j * (p + 1) <= n) ++p, ++WA;
			(cs[p] += (ll) ma[i] * mb[j] % mod) %= mod;
		}
		return ;
	} ;
	
	auto Dfs = [&] (auto self, int x, int i, int j) -> void {
		if(x == sz(curs)) {
			check(i, j);
			return ;
		}
		ll v = curs[x].first;
		int cnt = curs[x].second;
		L(t1, 0, cnt) {
			L(t2, 0, cnt - t1) {
				ll ni = i, nj = j;
				L(o, 1, t1) ni *= v;
				L(o, 1, t2) nj *= v;
				if(ni <= sq && nj <= sq) 
					self(self, x + 1, ni, nj);
			}
		} 
	};
	
	L(i, 1, lens) 
		MULS = TL + i - 1, curs = FAC[i], Dfs(Dfs, 0, 1, 1);
	
	L(i, 1, BOUND) {
		(cs[pointer[i]] += spin[i]) %= mod;
	}
	
	cerr << "WA = " << WA << ' ' << BOUND << endl;
//	L(i, 1, sq) {
//		L(j, sq / i + 1, sq) {
//			(cs[n / i / j] += (ll) ma[i] * mb[j] % mod) %= mod;
//		}
//	}
	
	L(i, 1, bp) if(cs[i]) (ans[getid(n / i)] += cs[i]) %= mod;
	L(i, 1, sq) if(mc[i]) (ans[getid(i)] += mc[i]) %= mod;
	
	R(i, tp - 1, 1) (ans[i] += ans[i + 1]) %= mod;
	return ans;
}
int f1[N], f2[N], trans[N];
int main() {
	ios :: sync_with_stdio(false);
	cin.tie(0); cout.tie(0); 
//	cin >> n;
	n = 1e10;
	Pinit(1 << 21);
	sq = sqrt(n);
	xxs(sq + 5);
	for(ll l = 1, r; l <= n; l = r + 1) {
		r = n / (n / l);
		w[++tp] = n / l;
		if(w[tp] <= sq) id1[w[tp]] = tp;
		else id2[n / w[tp]] = tp;
	}
	L(i, 1, tp) f1[i] = orz() % mod, f2[i] = orz() % mod; 
	L(i, 1, tp) {
		ll W = w[i];
		for(ll l = 1, r; l <= W; l = r + 1) 
			r = W / (W / l), 
			(trans[i] += (ll) (f1[getid(r)] - f1[getid(l - 1)] + mod) * f2[getid(W / l)] % mod) %= mod;
	}
	double clock1 = clock();
	cout << "solver1 : " << clock1 << endl;
	vi A(tp + 1), B(tp + 1);
	L(i, 1, tp) A[i] = f1[i], B[i] = f2[i];
	auto C = Mul(A, B);
	L(i, 1, tp) 
		if(trans[i] != C[i]) {
			cout << "wa" << endl;
		}
	double clock2 = clock();
	cout << "solver2 : " << clock2 - clock1 << endl;
	return 0;
}
posted @ 2023-07-11 15:45  zhoukangyang  阅读(13591)  评论(2编辑  收藏  举报