互质数对 题解

一、题目:

二、思路:

首先说明一下题面上的错误:【输入格式】中的第一行改为“第一行一个正整数\(N\)”。

然后来说一下这道题怎么做。

先说一下骗分。由于这道题数据较水,所以对于一个数\(a_i\),只需枚举与它相邻的五个数进行判断、更新答案即可。没排序做一遍、排完序再做一遍就能过去。很神奇吧!

好了,言归正传,我们来说说正解怎么做。

注意到题目中要求\(998244353\times a_i+a_j\)最小,就等价于以\(a_i\)为第一关键字、以\(a_j\)为第二关键字的有序对\((a_i,a_j)\)最小。所以一个很暴力的想法就是把\(a\)从小到大排序,然后对于一个数\(a_i\),若在\(a_{(i+1)\sim n}\)中存在与它互质的数,那么找见第一个与它互质的数\(a_j\),输出\(a_i\)\(a_j\)即可。

考虑如何优化。我们把优化的重心放在如何能快速判断是否\(\exists j\in [i+1,n]\),使得\(gcd(a_i,a_j)=1\)上。

考虑对每个\(a_i\),找出它的所有质因数\(p_1,p_2,\ldots,p_s\)。设\(y=p_1\times p_2\times \ldots \times p_s\),开一个数组\(val\),对于\(y\)的每一因数\(d\),让\(val[d]++\)

现在开始统计,我们统计的目标是有多少个\(j\in [i+1,n]\),满足\(a_j\)\(a_i\)不互质,记为\(cnt\)。枚举到一个新的\(i\)时,我们先撤销\(a_i\)\(val\)数组的影响,这样\(val\)数组始终保存的是下标为\(i+1\sim n\)的那些数的性质。考虑容斥原理,即:

  1. \(cnt\)加上与\(a_i\)至少有1个质因数相同的数的个数。
  2. \(cnt\)减去与\(a_i\)至少有2个质因数相同的数的个数。
  3. \(cnt\)加上与\(a_i\)至少有3个质因数相同的数的个数。
  4. ……

最后,只要\(cnt<n-i\),那么就存在与\(a_i\)互质的数。我们只需要找出来第一个与\(a_i\)互质的那个\(a_j\),输出答案。

如果所有的\(i\)都找遍了,也没有发现合法的情况,那只能输出-1了。

三、代码:

骗分骗到AC的代码:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;

#define FILEIN(s) freopen(s".in", "r", stdin)
#define FILEOUT(s) freopen(s".out", "w", stdout)

inline int read(void) {
    int x = 0, f = 1; char ch = getchar();
    while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
    return f * x;
}

const int maxn = 200005;

int n, a[maxn], ans1 = 0x3f3f3f3f, ans2 = 0x3f3f3f3f;

int gcd(int a, int b) {
    if (b == 0) return a;
    return gcd(b, a % b);
}

inline void find(void) {
    for (int i = 1; i <= n; ++ i) {
        for (int j = i + 1; j <= min(i + 5, n); ++ j) {
            int A = a[i], B = a[j];
            if (A > B) swap(A, B);
            if (gcd(A, B) == 1) {
                if (A == ans1) {
                    ans2 = min(ans2, B);
                }
                else if (A < ans1) {
                    ans1 = A;
                    ans2 = B;
                }
            }
        }
    }
}

int main() {
    FILEIN("cprime"); FILEOUT("cprime");
    n = read();
    for (int i = 1; i <= n; ++ i) a[i] = read();
    find();
    sort(a + 1, a + n + 1);
    find();
    if (ans1 == 0x3f3f3f3f) puts("-1");
    else cout << ans1 << " " << ans2 << endl;
    return 0;
}

正解:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>

using namespace std;

#define FILEIN(s) freopen(s".in", "r", stdin)
#define FILEOUT(s) freopen(s".out", "w", stdout)

inline int read(void) {
    int x = 0, f = 1; char ch = getchar();
    while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
    return f * x;
}

const int maxn = 200005, maxa = 2000005;

int n, a[maxn];
int val[maxa];

vector<int>factor[maxn];

inline void get(int x, int id) {
    for (int i = 2; i * i <= x; ++ i) {
        if (x % i == 0) {
            factor[id].push_back(i);
            while (x % i == 0) x /= i;
        }
    }
    if (x > 1) factor[id].push_back(x);
}

inline int call(int x, int i) { return (x >> i) & 1; }

inline void divide(int id, int k) {
    for (int s = 1; s < (1 << factor[id].size()); ++ s) {
        int x = 1;
        for (int i = 0; i < (int)factor[id].size(); ++ i) {
            if (call(s, i)) x *= factor[id][i];
        }
        val[x] += k;
    }
}

inline int count(int s) {
    int res = 0;
    while (s) {
        res += (s & 1);
        s >>= 1;
    }
    return res;
}

int gcd(int a, int b) {
    if (b == 0) return a;
    return gcd(b, a % b);
}

inline void solve(int i) {
    for (int j = i + 1; j <= n; ++ j) {
        if (gcd(a[i], a[j]) == 1) {
            printf("%d %d\n", a[i], a[j]);
            return;
        }
    }
}

int main() {
    FILEIN("cprime"); FILEOUT("cprime");

    n = read();
    for (int i = 1; i <= n; ++ i) a[i] = read();
    sort(a + 1, a + n + 1);

    for (int i = 1; i <= n; ++ i) {
        int x = a[i];
        get(x, i);
        divide(i, 1);
    }

    for (int id = 1; id < n; ++ id) {
        divide(id, -1);
        int sum = 0;
        for (int s = 1; s < (1 << factor[id].size()); ++ s) {
            int y = 1;
            for (int i = 0; i < (int)factor[id].size(); ++ i) {
                if (call(s, i)) y *= factor[id][i];
            }
            if (count(s) & 1)  // count(s) is an odd number
                sum += val[y];
            else // count(s) is an even number
                sum -= val[y];
        }
        if (sum != n - id) { solve(id); return 0; }
    }
    puts("-1");
    return 0;
}
posted @ 2021-05-05 21:55  蓝田日暖玉生烟  阅读(247)  评论(0编辑  收藏  举报