Luogu1829 JZPTAB

JZPTAB

\(\sum_{i=1}^n\sum_{j=1}^mlcm(i,j)\)

\(=\sum_{i=1}^n\sum_{j=1}^m\frac{ij}{\gcd(i,j)}\)

枚举gcd,这里默认n<m

\(=\sum_{p=1}^n\frac 1 p\sum_{i=1}^n\sum_{j=1}^mij[\gcd(i,j)=p]\)

\(=\sum_{p=1}^n\frac 1 p\sum_{i=1}^{n/p}\sum_{j=1}^{m/p}ijp^2[\gcd(i,j)=1]\)

\(=\sum_{p=1}^np\sum_{i=1}^{n/p}\sum_{j=1}^{m/p}ij[\gcd(i,j)=1]\)

推到这我一般喜欢接着推下去,不过观察了下你谷的题解发现可以适当地设一些函数,方便后面的思考、写代码等等

例如这里我们设\(g(x,y)=\sum_{i=1}^x\sum_{j=1}^yij[gcd(i,j)=1]\),那么ans\(=\sum_{p=1}^np\cdot g(\frac np,\frac mp)\),假设我们能快速地求出g值,那么我们就可以打个两个数的数论分块,在根号求出ans

然后我们考虑\(g(x,y)=\sum_{i=1}^x\sum_{j=1}^yij[gcd(i,j)=1]\)

\(g(x,y)=\sum_{i=1}^x\sum_{j=1}^yij\sum_{d|i,d|j}\mu(d)\)

\(g(x,y)=\sum_{d=1}^n\mu(d)d^2\sum_{i=1}^{x/d}\sum_{j=1}^{y/d}ij\)

\(g(x,y)=\sum_{d=1}^n\mu(d)d^2(\sum_{i=1}^{x/d}i)(\sum_{j=1}^{y/d}j)\)

显然\(\sum_{i=1}^ni=\frac{n(n+1)}2\),为了简便,我们设\(f(x)=\frac{x(x+1)}2\)

\(g(x,y)=\sum_{d=1}^n\mu(d)d^2f(\lfloor\frac x d\rfloor)f(\lfloor\frac y d\rfloor)\)

这个显然也可以用数论分块那套理论,复杂度为根号

两个根号套一起就是\(O(n)\)了。这题\(n\)\(10^7\),要稍微卡卡常数。。。

不用卡常数,交上去第一遍WA了,define int longlong就A了。。。

没开O2,一共9248ms,跑的飞慢

#include <cstdio>
#include <functional>
using namespace std;

#define int long long

bool vis[10000010];
int prime[1000000], tot;
long long mu[10000010];

const int fuck = 10000000, p = 20101009;

int n, m;

int f(int x) { return x * (long long)(x + 1) / 2 % p; }

int g(int x, int y)
{
	int res = 0;
	if (x > y) swap(x, y);
	for (int i = 1, j; i <= x; i = j + 1)
	{
		j = min(x / (x / i), y / (y / i));
		res = (res + (mu[j] - mu[i - 1]) * (long long)f(x / i) % p * (long long)f(y / i) % p) % p;
		if (res < 0) res += p;
	}
	return res;
}

signed main()
{
	mu[1] = 1;
	for (int i = 2; i <= fuck; i++)
	{
		if (vis[i] == false) prime[++tot] = i, mu[i] = -1;
		for (int j = 1; j <= tot && i * prime[j] <= fuck; j++)
		{
			vis[i * prime[j]] = true;
			if (i % prime[j] == 0)
				break;
			mu[i * prime[j]] = -mu[i];
		}
		mu[i] *= i * i;
		mu[i] += mu[i - 1];
		mu[i] %= p;
	}
	int n, m;
	scanf("%lld%lld", &n, &m);
	if (n > m) swap(n, m);
	int ans = 0;
	for (int i = 1, j; i <= n; i = j + 1)
	{
		j = min(n / (n / i), m / (m / i));
		ans = (ans + (j - i + 1) * (long long)(i + j) / 2 % p * (long long)g(n / i, m / i) % p) % p;
	}
	printf("%lld\n", ans);
	return 0;
}
posted @ 2019-01-20 11:42  ghj1222  阅读(127)  评论(0编辑  收藏  举报