#1947 道路 || CF1214F Employment
不难观察到一个性质:可以找到一条边 \((i,i+1)\),满足存在一个最优解,其所有匹配的路径不经过这条边,称之为分界线。可以调整证明。
如果我们已知了分界线,不妨设为 \((m,1)\)。那么最小权匹配就是类似括号匹配,贪心扫一遍即可。
这个不是很好优化,考虑对每条边算贡献。不妨令两类点的权值为 \(1\) 和 \(-1\),那么一条边 \((i,i+1)\) 的贡献就是 \(|pre_i|\)。其中 \(pre\) 为前缀和数组。
对于分界线不固定的情况,不妨设从 \((i-1,i)\) 分开,那么一条边 \((j,j+1)\) 的贡献就是 \(|pre_j-pre_i|\)。
考虑写出贡献和的式子 \(\sum\limits_{j=1}^m|pre_j-pre_i|\),最小化该式只需要取 \(pre_i\) 为所有 \(pre\) 的中位数即可。\(pre\) 的值域不大,可以开个桶扫一遍。
瓶颈在于求出 \(pre\) 需要排序,用 bitset
优化可以做到 \(O(n+\dfrac{m}{w})\)。
code
#include <bits/stdc++.h>
using namespace std;
void file() {
freopen("1.in", "r", stdin);
freopen("1.out", "w", stdout);
}
using ll = long long;
namespace IO {
char buf[1 << 20], *p1, *p2;
#define getchar() ((p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 20, stdin), p1 == p2)) ? 0 : *p1++)
int read() {
char ch = getchar();
int x = 0, f = 1;
while (ch < '0' || ch > '9') {
if (ch == '-')
f = -1;
ch = getchar();
}
while (ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + (ch ^ 48), ch = getchar();
return x * f;
}
}
using IO::read;
const ll kInfl = 1e18;
const int kN = 2e7 + 5, kM = 5e7 + 5;
int n, m;
array<int, kN> a, ord, sum;
bitset<kM> buc;
void Sort(int *a) {
for(int i = 1; i <= n; i++) buc[a[i]] = 1;
for(int i = 1, p = 0; i <= n; i++)
a[i] = p = buc._Find_next(p), buc[p] = 0;
}
void calc() {
int tot = 0, p = 1, q = n + 1;
for(; (p <= n) && (q <= 2 * n); )
ord[++tot] = ((a[p] < a[q]) ? p++ : q++);
for(; p <= n; ) ord[++tot] = p++;
for(; q <= 2 * n; ) ord[++tot] = q++;
ord[++tot] = ord[1];
for(int i = 1, pre = n; i <= 2 * n; i++) {
int x = ord[i], y = ord[i + 1];
int len = a[y] - a[x];
if(len < 0) len += m;
sum[pre += ((x <= n) ? 1 : -1)] += len;
}
}
void clear() {
fill_n(sum.data(), 2 * n + 1, 0);
}
void solve() {
m = read(), n = read(), clear();
for(int i = 1; i <= 2 * n; i++) a[i] = read();
Sort(a.data()), Sort(a.data() + n), calc();
int p = 0, cur = 0;
for(; (cur += sum[p]) * 2 < m; p++) ;
ll ans = 0;
for(int i = 1, pre = n - p; i <= 2 * n; i++) {
int x = ord[i], y = ord[i + 1];
int len = a[y] - a[x];
if(len < 0) len += m;
ans += (ll)len * abs(pre += ((x <= n) ? 1 : -1));
}
cout << ans << "\n";
}
int main() {
// file();
for(int T = read(); T--; solve()) ;
return 0;
}