【题解】Luogu-P1447 [NOI2010] 能量采集

Description

  • 给定整数 \(n, m\),求

\[\sum_{i = 1}^n \sum_{j = 1}^m 2 \gcd(i, j) - 1 \]

  • 对于 \(100\%\) 的数据:\(1 \le n, m \le 10^5\)

Solution

不妨设 \(n\le m\)

\[\begin{aligned} \sum_{i = 1}^n \sum_{j = 1}^m 2 \gcd(i, j) - 1 & = - nm + 2 \sum_{i = 1}^n \sum_{j = 1}^m \gcd(i, j) \end{aligned} \]

把后面一项单独拎出来,就是大家喜闻乐见的欧拉反演。

\[\begin{aligned} \sum_{i = 1}^n \sum_{j = 1}^m \gcd(i, j) & = \sum_{i = 1}^n \sum_{j = 1}^m \sum_{d \mid \gcd(i, j)} \varphi(d) \\ & = \sum_{d = 1}^n \varphi(d) \sum_{i = 1}^n [d\mid i] \sum_{j = 1}^m [d\mid j] \\ & = \sum_{d = 1}^n \varphi(d) \left\lfloor\dfrac{n}{d}\right\rfloor \left\lfloor\dfrac{m}{d}\right\rfloor \end{aligned} \]

预处理前缀和 + 整除分块即可。

14:21 我试试能不能用杜教筛来卡最优解

14:31 我杜教筛 TLE 了

14:43 我杜教筛用了 \(37ms\)???用容斥才 \(25ms\)

14:48 哦是因为数据范围太小然后杜教筛常数太大

Code

// 18 = 9 + 9 = 18.
#include <iostream>
#include <cstdio>
#include <unordered_map>
#define Debug(x) cout << #x << "=" << x << endl
typedef long long ll;
using namespace std;

const int MAXN = 2154 + 5;
const int N = 2154;

int p[MAXN], phi[MAXN];
ll sum[MAXN];
bool vis[MAXN];

void pre()
{
	phi[1] = sum[1] = 1;
	for (int i = 2; i <= N; i++)
	{
		if (!vis[i])
		{
			p[++p[0]] = i;
			phi[i] = i - 1;
		}
		for (int j = 1; j <= p[0] && i * p[j] <= N; j++)
		{
			vis[i * p[j]] = true;
			if (i % p[j] == 0)
			{
				phi[i * p[j]] = phi[i] * p[j];
				break;
			}
			phi[i * p[j]] = phi[i] * phi[p[j]];
		}
		sum[i] = sum[i - 1] + phi[i];
	}
}

unordered_map<ll, ll> dp;

ll sublinear(int n)
{
	if (n <= N)
	{
		return sum[n];
	}
	if (dp.find(n) != dp.end())
	{
		return dp[n];
	}
	ll res = (ll)n * (n + 1) / 2;
	for (int l = 2, r; l <= n; l = r + 1)
	{
		int k = n / l;
		r = n / k;
		res -= (r - l + 1) * sublinear(k);
	}
	return dp[n] = res;
}

ll getsum(int l, int r)
{
	return sublinear(r) - sublinear(l - 1);
}

ll block(int n, int m)
{
	ll res = 0;
	for (int l = 1, r; l <= n; l = r + 1)
	{
		int k1 = n / l, k2 = m / l;
		r = min(n / k1, m / k2);
		res += getsum(l, r) * k1 * k2;
	}
	return res;
}

int main()
{
	pre();
	int n, m;
	scanf("%d%d", &n, &m);
	if (n > m)
	{
		swap(n, m);
	}
	printf("%lld\n", - (ll)n * m + 2 * block(n, m));
	return 0;
}
posted @ 2022-01-18 14:51  mango09  阅读(42)  评论(0编辑  收藏  举报
-->