最小乘积生成树的另类做法
最小乘积生成树是最小生成树的变形,每条边有一个权值$(a_i, b_i)$, 我们要求一棵生成树,使得$\sum{a_i} \cdot \sum{b_i}$最小。
网上大多数做法是,把解空间看做二维平面上的点,$\sum{a_i}$ $\sum{b_i}$分别看做点的横纵坐标。显然最优解一定是在解集构成的下凸壳上。
这里需要用到另外一种求凸壳的方法。先确定最左边的点$A$,再确定最右边的点$B$,然后找到离直线$BA$最远的点$C$也就是满足$\overrightarrow{BA} \times \overrightarrow{BC}$最大的点$C$,这个点一定在凸壳上,然后递归$AC$, $CB$部分的凸壳。 似乎可以证明下凸壳上的点是$O(M^2)$级别的,总的复杂度是$O(M^3logM)$的。
最近做Petrozavodsk Winter-2014. Moscow SU Tapir Contest的A题的时候学到了另外一种姿势,理论复杂度是一样的,但是实现起来更加方便,在此分享一下。
原题是给定$N$个点,每个点有一个权值$(a_i, b_i)$,选取恰好$K$个点使得$\sum{a_i} \cdot \sum{b_i}$最小。显然这个题可以套用最小乘积生成树求下凸壳的做法,时间复杂度是$O(N^3logN)$ 。 用之后介绍的算法同样可以做到$O(N^3logN)$,而且可以通过一些手段优化做到$O(N^2logN)$。
这个做法可以扩展回最小乘积生成树,不过遗憾的是,扩展回最小乘积生成树问题不能通过本题采用的手段降低复杂度,还是只能做到$O(M^3logM)$。
参考知乎的回答。
首先是一个非常重要的转化:一定存在某个常数$\lambda(\lambda \geq 0)$,使得$a_i + \lambda b_i$前K小的点集与原问题的最优解相同。(对应到最小乘积生成树问题中,即将边权按$a_i + \lambda b_i$排序做kruskal算法得到的最小生成树就是最小乘积生成树, 以下内容将考虑取$K$个点的问题,请自行对应到最小乘积生成树问题)。
证明: 假设最优解的$\sum{a_i} = A, \sum{b_i} = B$, 取$\lambda = \frac{B}{A}$即可。
取$\lambda = \frac{B}{A}$. 则$a_i + \lambda b_i$前K小的点集的$\sum{a_i} = A', \sum{b_i} = B'$.
则$A' + \lambda B' \leq A + \lambda B$
若$A' + \lambda B' \lt A + \lambda B$
则$A' \cdot \lambda B' \leq (\frac{A' + \lambda B'}{2})^2 \lt (\frac{A+\lambda B}{2})^2 = \lambda AB $ 即 $ A'B' < AB$,于是产生了矛盾。
因此一定有$A' + \lambda B' = A + \lambda B$. 所以最优解一定对应$\sum{a_i + \lambda b_i}$最小的某个解。
而我们的目的是$\sum{a_i} \cdot \sum{b_i}$最小,因此$\sum{a_i + \lambda b_i}$最小还不够,还要使得$\sum{a_i}$最小或者$\sum{b_i}$最小.
因此,我们只要枚举所有的$\lambda$,然后按$a_i + \lambda b_i$为第一关键字,$a_i$为第二关键字,排序做一遍,再按$a_i + \lambda b_i$为第一关键字,$b_i$为第二关键字做一遍,一定可以得到最优解。
虽然$\lambda$的取值有无穷多种,但实际上我们只要考虑$a_i + \lambda b_i$的排序结果不同的$\lambda$。 即使得$a_i + \lambda b_i = a_j + \lambda b_j$的这些$\lambda$。 再进一步分析,其实我们甚至不需要对每个$\lambda$分别以$a_i$和$b_i$为第二关键字排序做两遍。假设使得$a_i + \lambda b_i = a_j + \lambda b_j$的这些$\lambda$分别是$\lambda_1, \lambda_2, \cdots, \lambda_k$, 我们只需要分别在区间$(0, \lambda_1), (\lambda_1, \lambda_2) \cdots (\lambda_{k-1}, \lambda_k), (\lambda_k, +inf)$各取一个$\lambda$,仅仅按照$a_i + \lambda b_i$排序做一遍即可。
因此对于本题总共只要做$O(N^2)$遍排序,总的复杂度$O(N^3logN)$;对于最小乘积生成树问题,只要做$O(M^2)$遍普通最小生成树即可。
对于本题,只要枚举$\lambda$然后排序取前$K$小。如果按从小到大的顺序枚举$\lambda$,每次改变的时候我们可以知道哪些元素的排名上升了,哪些下降了,因此可以直接维护,不必每次重新排序。具体只要每次将那些排名会发生变化的元素重新拿出来排序,然后插回去。可以想象成,数轴上有$N$辆车,起始位置分别是$a_i$,速度分别是$b_i$,而$\lambda$可以看做时间,随着$\lambda$的变大,维护车的排名。 复杂度可以降低到$O(N^2logN)$.
然而,最小乘积生成树问题并不能这样维护,因为kruskal算法不是单纯的取最小的几条边。
本题代码:
1 #include <bits/stdc++.h> 2 using namespace std; 3 4 typedef long long LL; 5 #define N 1010 6 int a[N], b[N], rk[N], now[N], n, K; 7 double val[N]; 8 LL ans = 1e18; 9 vector<int> best; 10 const double EPS = 1e-6; 11 12 13 struct event 14 { 15 double t; 16 int k1, k2; 17 bool operator < (const event &o)const 18 { 19 return t < o.t; 20 } 21 }; 22 vector<event> L; 23 24 25 bool cmp(int x, int y) {return val[x] < val[y];} 26 27 int main() 28 { 29 //freopen("in.txt", "r", stdin); 30 31 scanf("%d %d", &n, &K); 32 for (int i = 1; i <= n; ++i) 33 scanf("%d %d", &a[i], &b[i]), now[i] = i, val[i] = a[i]; 34 sort(now + 1, now + n + 1, cmp); 35 for (int i = 1; i <= n; ++i) rk[now[i]] = i; 36 37 38 LL sa = 0, sb = 0; 39 for (int i = 1; i <= K; ++i) 40 { 41 sa += a[now[i]], sb += b[now[i]]; 42 best.push_back(now[i]); 43 } 44 ans = sa * sb; 45 46 47 for (int i = 1; i < n; ++i) 48 { 49 for (int j = i + 1; j <= n; ++j) 50 { 51 if (b[i] == b[j]) continue; 52 //a[i] + k*b[i] = a[j] + k * b[j] 53 if (1LL * (a[j] - a[i]) * (b[i] - b[j]) >= 0) 54 L.push_back((event){(double)(a[j] - a[i]) / (b[i] - b[j]), i, j}); 55 } 56 } 57 sort(L.begin(), L.end()); 58 59 for (int i = 0, j; i < L.size(); i = j + 1) 60 { 61 vector<int> lis; 62 vector<int> pos; 63 64 j = i; 65 while (j + 1 < L.size() && fabs(L[j + 1].t - L[i].t) < EPS) ++j; 66 for (int k = i; k <= j; ++k) 67 { 68 lis.push_back(L[k].k1); 69 lis.push_back(L[k].k2); 70 } 71 sort(lis.begin(), lis.end()); 72 lis.erase(unique(lis.begin(), lis.end()), lis.end()); 73 74 double lamda = j + 1 == L.size()? L[i].t + 1: (L[j].t + L[j + 1].t) / 2.0; 75 for (auto x: lis) val[x] = a[x] + lamda * b[x], pos.push_back(rk[x]); 76 sort(lis.begin(), lis.end(), cmp); 77 sort(pos.begin(), pos.end()); 78 for (int k = 0; k < lis.size(); ++k) 79 { 80 int x = lis[k]; 81 //cout << "!! " << x << " " << rk[x] << " " << val[x] << endl; 82 if (rk[x] <= K) sa -= a[x], sb -= b[x]; 83 if (pos[k] <= K) sa += a[x], sb += b[x]; 84 rk[x] = pos[k]; 85 now[pos[k]] = x; 86 } 87 assert(sa > 0 && sb > 0); 88 if (sa * sb < ans) 89 { 90 ans = sa * sb; 91 best.clear(); 92 for (int k = 1; k <= K; ++k) 93 best.push_back(now[k]); 94 } 95 } 96 97 sa = sb = 0; 98 printf("%lld\n", ans); 99 sort(best.begin(), best.end()); 100 for (int i = 0; i < best.size(); ++i) 101 printf("%d%c", best[i], i + 1 == best.size()? '\n': ' '), sa += a[best[i]], sb += b[best[i]]; 102 assert(sa * sb == ans); 103 return 0; 104 }