莫比乌斯反演
推荐教程
莫比乌斯反演常用的两种形式:
题目
Visible Lattice Points
题意
在一个\(n\times n\)的坐标轴上,问你有多少个点可以被\((0,0,0)\)看到。
思路
我们知道一个点\((x,y,z)\)要想被\((0,0,0)\)看到,那么\((x,y,z)\)与\((0,0,0)\)的连线上就不能有其他点存在,因此这个题求得就是\(\sum\limits_{i=1}^{n}\sum\limits_{j=1}^{n}\sum\limits_{j=1}^{n}[gcd(i,j,k)=1]\)。
我们首先定义:
然后我们用公式\((2)\)来反演:
由\(F(n)\)的定义我们知道
反演得到
将\(f(1)\)代入所求式子可以得到
但是需要注意一点,那就是点在某坐标平面和某坐标轴上的情况,因此其实最终答案应该是\(\sum\limits_{i=1}^{n}\mu(i)((n/i)^3+3(n/i)^2)+3\)。
由于这个题目的\(T\leq50\),因此我们可以用\(O(n)\)来写,但是如果\(T\)大一点的化就需要使用整除分块来写,这里就只贴整除分块的代码了。
代码实现如下
#include <set>
#include <map>
#include <deque>
#include <queue>
#include <stack>
#include <cmath>
#include <ctime>
#include <bitset>
#include <cstdio>
#include <string>
#include <vector>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long LL;
typedef pair<LL, LL> pLL;
typedef pair<LL, int> pLi;
typedef pair<int, LL> piL;;
typedef pair<int, int> pii;
typedef unsigned long long uLL;
#define lson rt<<1
#define rson rt<<1|1
#define lowbit(x) x&(-x)
#define name2str(name) (#name)
#define bug printf("*********\n")
#define debug(x) cout<<#x"=["<<x<<"]" <<endl
#define FIN freopen("in","r",stdin)
#define IO ios::sync_with_stdio(false),cin.tie(0)
const double eps = 1e-8;
const int mod = 1e9 + 7;
const int maxn = 1e6 + 7;
const double pi = acos(-1);
const int inf = 0x3f3f3f3f;
const LL INF = 0x3f3f3f3f3f3f3f3fLL;
int t, n, cnt;
int isp[maxn], v[maxn], mu[maxn];
void init() {
mu[1] = 1;
for(int i = 2; i < maxn; ++i) {
if(!v[i]) {
v[i] = 1;
isp[cnt++] = i;
mu[i] = -1;
}
for(int j = 0; j < cnt; ++j) {
if(isp[j] * i > maxn) break;
v[i*isp[j]] = 1;
if(i % isp[j] == 0) {
mu[i*isp[j]] = 0;
break;
}
mu[i*isp[j]] = -mu[i];
}
}
for(int i = 2; i < maxn; ++i) mu[i] += mu[i-1];
}
int main() {
#ifndef ONLINE_JUDGE
FIN;
#endif
init();
scanf("%d", &t);
while(t--) {
scanf("%d", &n);
LL ans = 0;
for(int l = 1, r; l <= n; l = r + 1) {
r = min(n, n / (n / l));
int x = n / l;
LL sum = 1LL * x * x * x + 3LL * x * x + 3LL * x;
ans += sum * (mu[r] - mu[l-1]);
}
printf("%lld\n", ans);
}
return 0;
}
下面的代码基本上都和上面的差不多所以就不写代码啦~
GCD
答案为\(\sum\limits_{i=1}^{n}\sum\limits_{j=1}^{m}[gcd(i,j)=k]=\sum\limits_{i=1}^{\lfloor\frac{n}{k}\rfloor}\sum\limits_{j=1}^{\lfloor\frac{m}{k}\rfloor}[gcd(i,j)=1]\)。
定义
则
所以最终答案为
不过要记得去重哦~
小D的Lemon
其中
我们首先将\(gcd(i,j)\)提出来,然后变成\(\prod\limits_{k=1}^{min(n,m)}g(k)^{\prod\limits_{i=1}^{n}\prod\limits_{j=1}^{m}[gcd(i,j)=k]}\),通过反演我们可以得到
由于\(T\)比较大,因此\(O(n)\)的复杂度是无法通过的,因此我们需要预处理出\(\prod\limits_{t|T}g(t)\),然后就可以用整除分块处理即可,总复杂度为\(O(nlog(n)+T\sqrt n)\)。
#include <set>
#include <map>
#include <deque>
#include <queue>
#include <stack>
#include <cmath>
#include <ctime>
#include <bitset>
#include <cstdio>
#include <string>
#include <vector>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long LL;
typedef pair<LL, LL> pLL;
typedef pair<LL, int> pLi;
typedef pair<int, LL> piL;;
typedef pair<int, int> pii;
typedef unsigned long long uLL;
#define lson rt<<1
#define rson rt<<1|1
#define lowbit(x) x&(-x)
#define name2str(name) (#name)
#define bug printf("*********\n")
#define debug(x) cout<<#x"=["<<x<<"]" <<endl
#define FIN freopen("in","r",stdin)
#define IO ios::sync_with_stdio(false),cin.tie(0)
const double eps = 1e-8;
const int mod = 1e9 + 7;
const int maxn = 3e5 + 6;
const double pi = acos(-1);
const int inf = 0x3f3f3f3f;
const LL INF = 0x3f3f3f3f3f3f3f3fLL;
int t, n, m, cnt;
int isp[maxn], mu[maxn], v[maxn];
LL g[maxn], f[maxn], inv[maxn], invv[maxn];
LL qpow(LL x, LL n) {
LL res = 1;
while(n) {
if(n & 1) res = res * x % mod;
x = x * x % mod;
n >>= 1;
}
return res;
}
void init() {
mu[1] = g[1] = 1;
for(int i = 2; i < maxn; ++i) {
if(!v[i]) {
g[i] = 1;
mu[i] = -1;
isp[cnt++] = i;
}
for(int j = 0; j < cnt && i * isp[j] < maxn; ++j) {
v[i*isp[j]] = 1;
g[i*isp[j]] = g[i] + 1;
mu[i*isp[j]] = -mu[i];
if(i % isp[j] == 0) {
mu[i*isp[j]] = 0;
break;
}
}
}
for(int i = 1; i < maxn; ++i) {
f[i] = 1;
invv[i] = qpow(g[i], mod - 2);
}
for(int i = 2; i < maxn; ++i) {
for(int j = i; j < maxn; j += i) {
if(mu[j/i] == 1) f[j] = f[j] * g[i] % mod;
else if(mu[j/i] == -1) f[j] = f[j] * invv[i] % mod;
}
}
f[0] = inv[0] = inv[0] = inv[1] = 1;
for(int i = 2; i < maxn; ++i) {
f[i] = f[i] * f[i-1] % mod;
inv[i] = qpow(f[i], mod - 2);
}
}
int main() {
init();
scanf("%d", &t);
while(t--) {
scanf("%d%d", &n, &m);
if(n > m) swap(n, m);
LL ans = 1;
for(int l = 1, r; l <= n; l = r + 1) {
r = min(n / (n / l), m / (m / l));
ans = ans * qpow(f[r] * inv[l-1] % mod, 1LL * (n / l) * (m / l) % (mod - 1)) % mod;
}
printf("%lld\n", ans);
}
return 0;
}
小清新数论
#include <set>
#include <map>
#include <deque>
#include <queue>
#include <stack>
#include <cmath>
#include <ctime>
#include <bitset>
#include <cstdio>
#include <string>
#include <vector>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long LL;
typedef pair<LL, LL> pLL;
typedef pair<LL, int> pLi;
typedef pair<int, LL> piL;;
typedef pair<int, int> pii;
typedef unsigned long long uLL;
#define lson rt<<1
#define rson rt<<1|1
#define lowbit(x) x&(-x)
#define name2str(name) (#name)
#define bug printf("*********\n")
#define debug(x) cout<<#x"=["<<x<<"]" <<endl
#define FIN freopen("in","r",stdin)
#define IO ios::sync_with_stdio(false),cin.tie(0)
const double eps = 1e-8;
const int mod = 998244353;
const int maxn = 1e7 + 7;
const double pi = acos(-1);
const int inf = 0x3f3f3f3f;
const LL INF = 0x3f3f3f3f3f3f3f3fLL;
int n, cnt;
int v[maxn], isp[maxn], phi[maxn], mu[maxn];
void init() {
mu[1] = phi[1] = 1;
for(int i = 2; i <= n; ++i) {
if(!v[i]) {
v[i] = 1;
mu[i] = -1;
phi[i] = i - 1;
isp[cnt++] = i;
}
for(int j = 0; j < cnt; ++j) {
if(isp[j] > n / i) break;
v[isp[j]*i] = 1;
mu[i*isp[j]] = -mu[i];
if(i % isp[j] == 0) {
mu[i*isp[j]] = 0;
phi[i*isp[j]] = phi[i] * isp[j] % mod;
break;
}
phi[i*isp[j]] = phi[i] * (isp[j] - 1) % mod;
}
}
for(int i = 2; i <= n; ++i) (phi[i] += phi[i-1]) %= mod;
}
int main() {
scanf("%d", &n);
init();
LL ans = 0;
for(int i = 1; i <= n; ++i) {
ans = ((ans + mu[i] * ((2LL * phi[n/i] % mod + mod) % mod - 1 + mod)) % mod + mod) % mod;
}
printf("%lld\n", ans);
return 0;
}