HDU 6053 TrickGCD (莫比乌斯反演)
Description
给出序列\(A\) ,问有多少不同的序列\(B\)满足以下条件:
- \(1 \leqslant b_i \leqslant a_i\)
- \(gcd(B_1,B_2,\cdots, B_n) \geqslant 2\)
Input
第一行一个整数 \(T\)表示用例组数,每组用例首先输入一个整数 \(n\) 表示序列长度,之后输入\(n\)个整数表示序列\(A\)。\(1 \leqslant n,a_i \leqslant 10^5\)。
Output
对于每组用例,输出满足条件的序列\(B\)的数量。结果模\(1e9+7\)。
Sample Input
1
4
4 4 4 4
Sample Output
Case #1: 17
Solution
令\(f(k)\)表示\(gcd(B)=k\)的序列数量,\(g(k)\)表示\(k|gcd(B)\)的数量,有
\[g(k)=\sum_{k|d}{f(d)}
\]
莫比乌斯反演得
\[f(k)=\sum_{k|d}{\mu(\frac{d}{n})g(d)}
\]
最终答案为
\[ans=g(1)-f(1)=g(1)-\sum_{d=1}^{min\{a_i\}}{\mu(d)g(d)}=-\sum_{d=2}^{min\{a_i\}}{\mu(d)g(d)}
\]
其中
\[g(d)=\prod_{i=1}^{n}{\left \lfloor \frac{a_i}{d} \right \rfloor}
\]
预处理莫比乌斯函数的时间复杂度是\(O(nlogn)\),然后累乘累加的时间复杂度为\(O(n^2)\),会超时。此时注意\(a_i\)最大只有\(10^5\)且\(\left \lfloor a_i/d \right \rfloor\)会有很多重复,故可以用如下方法在\(O(nlogn)\)时间内求出每一个\(g(d)\)
令\(sum[i]\)表示\(a_k \leqslant i\)的\(a_k\)的个数\((1 \leqslant i \leqslant 10^5)\),由于当\(j\times d \leqslant a_i \leqslant j \times d + d- 1\)时都有\(\left \lfloor a_i/d \right \rfloor=j\),于是
\[g(d)=\prod_{i=1}^{n}{\left \lfloor \frac{a_i}{d} \right \rfloor}=\prod_{j=1}^{\lfloor \frac{max\{a_i\}}{d} \rfloor}{j^{sum[j\times d+d-1]-sum[j\times d-1]}}
\]
整体时间复杂度为\(O(Tnlongn)\)
Code
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const int INF = 0x3f3f3f3f;
const ll mod = 1e9 + 7;
const int N = 1e5 + 10;
int prime[N], mu[N];
bool flag[N];
void Mobius(int n)
{
memset(flag, false, sizeof(flag));
mu[1] = 1;
int tot = 0;
for (int i = 2; i <= n; i++)
{
if (!flag[i]) prime[tot++] = i, mu[i] = -1;
for (int j = 0; j < tot && i * prime[j] <= n; j++)
{
int k = i * prime[j];
flag[k] = true;
if (i % prime[j] == 0) { mu[k] = 0; break; }
else mu[k] = -mu[i];
}
}
}
ll power(ll a, ll n, ll m)
{
ll ans = 1;
while (n)
{
if (n & 1) ans = ans * a % m;
a = a * a % m;
n >>= 1;
}
return ans;
}
int sum[2 * N];
int main()
{
int T;
scanf("%d", &T);
Mobius(1e5);
for (int cas = 1; cas <= T; cas++)
{
int n;
scanf("%d", &n);
memset(sum, 0, sizeof(sum));
int minx = INF, maxx = 0;
for (int i = 1; i <= n; i++)
{
int x;
scanf("%d", &x);
sum[x]++;
minx = min(minx, x);
maxx = max(maxx, x);
}
for (int i = 1; i <= 2 * maxx; i++) sum[i] = sum[i - 1] + sum[i];
ll ans = 0;
for (int i = 2; i <= minx; i++)
{
if (mu[i] == 0) continue;
ll s = 1;
for (int j = 1; j * i <= maxx; j++)
{
int num = sum[j * i + i - 1] - sum[j * i - 1];
s = s * power(j, num, mod) % mod;
}
ans = ((ans - mu[i] * s) % mod + mod) % mod;
}
printf("Case #%d: %lld\n", cas, ans);
}
return 0;
}