zoj3988
zoj3988
题意
如果一个集合 \(\{i,j\}\) 满足 \(i\neq j\) 且 \(a[i]+a[j]\) 是素数,则称之为素数集合。
给出一些数字,这些数字可以组成多种素数集合,从这些集合中最多选择 \(k\) 个集合,问这些集合涉及到的数的数量最大值为多少。
分析
存在匹配关系即 \(a[i]+a[j]\) 是素数,那么 \(i\) \(j\) 就可以连边,要求两个数的和为素数,那么这两个数一定有一奇数一偶数(也有可能两个 \(1\)),将奇数放左边,偶数放右边,建二分图。求一发二分图匹配即可。
要注意的是两个 1 的和也是素数,首先奇数 1 放在最后不急着匹配,如果 \(1\) 的数量大于 \(1\) 的话,先将两个 \(1\) 组合起来。
code
#include<bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 2e6 + 10;
const int MAXN = 3e3 + 5;
int notprime[N];
int n, k;
vector<int> odd, even;
int vis[MAXN], has[MAXN], a[MAXN], b[MAXN];
vector<int> v[MAXN];
int dfs(int x) {
for(int i = 0; i < v[x].size(); i++) {
int to = v[x][i];
if(!vis[to]) {
vis[to] = 1;
if(has[to] == -1 || dfs(has[to])) {
has[to] = x;
return 1;
}
}
}
return 0;
}
int main() {
for(int i = 2; i < N; i++) if(!notprime[i]) {
for(ll j = 1LL * i * i; j < N; j += i) notprime[j] = 1;
}
int T;
scanf("%d", &T);
while(T--) {
odd.clear();
even.clear();
memset(has, -1, sizeof has);
scanf("%d%d", &n, &k);
int cnt = 0;
for(int i = 0; i < n; i++) {
int x;
scanf("%d", &x);
if(x == 1) cnt++;
else if(x & 1) odd.push_back(x);
else even.push_back(x);
}
int cc = cnt;
while(cc--) odd.push_back(1);
int ans = 0;
for(int i = 0; i < odd.size(); i++) {
v[i].clear();
for(int j = 0; j < even.size(); j++) {
if(!notprime[odd[i] + even[j]]) {
v[i].push_back(j);
}
}
memset(vis, 0, sizeof vis);
if(dfs(i) && k) {
ans += 2;
k--;
}
}
memset(a, 0, sizeof a);
memset(b, 0, sizeof b);
for(int i = 0; i < even.size(); i++) {
if(has[i] != -1) {
a[has[i]] = 1;
b[i] = 1;
}
}
int cnt1 = 0;
for(int i = 0; i < odd.size() && k; i++) {
if(!a[i] && odd[i] == 1) cnt1++;
}
while(k && cnt1 > 1) {
cnt1 -= 2;
ans += 2;
k--;
}
int flg = cnt1;
for(int i = 0; i < odd.size() && k; i++) {
if(!a[i] && odd[i] == 1) {
if(flg) {
flg--;
} else {
a[i] = 1;
}
}
}
for(int i = 0; i < odd.size() && k; i++) {
for(int j = 0; j < v[i].size() && k; j++) {
int to = v[i][j];
if(a[i] && !b[to]) { ans++; k--; b[to] = 1; }
else if(!a[i] && b[to]) { ans++; k--; if(odd[i] == 1) cnt1--; a[i] = 1; }
}
}
if(k && cnt > 1 && cnt1) ans++;
printf("%d\n", ans);
}
return 0;
}