最小乘积生成树的另类做法

最小乘积生成树是最小生成树的变形,每条边有一个权值$(a_i, b_i)$, 我们要求一棵生成树,使得$\sum{a_i} \cdot \sum{b_i}$最小。

网上大多数做法是,把解空间看做二维平面上的点,$\sum{a_i}$ $\sum{b_i}$分别看做点的横纵坐标。显然最优解一定是在解集构成的下凸壳上。

这里需要用到另外一种求凸壳的方法。先确定最左边的点$A$,再确定最右边的点$B$,然后找到离直线$BA$最远的点$C$也就是满足$\overrightarrow{BA} \times \overrightarrow{BC}$最大的点$C$,这个点一定在凸壳上,然后递归$AC$, $CB$部分的凸壳。 似乎可以证明下凸壳上的点是$O(M^2)$级别的,总的复杂度是$O(M^3logM)$的。

最近做Petrozavodsk Winter-2014. Moscow SU Tapir Contest的A题的时候学到了另外一种姿势,理论复杂度是一样的,但是实现起来更加方便,在此分享一下。

原题是给定$N$个点,每个点有一个权值$(a_i, b_i)$,选取恰好$K$个点使得$\sum{a_i} \cdot \sum{b_i}$最小。显然这个题可以套用最小乘积生成树求下凸壳的做法,时间复杂度是$O(N^3logN)$ 。 用之后介绍的算法同样可以做到$O(N^3logN)$,而且可以通过一些手段优化做到$O(N^2logN)$。

这个做法可以扩展回最小乘积生成树,不过遗憾的是,扩展回最小乘积生成树问题不能通过本题采用的手段降低复杂度,还是只能做到$O(M^3logM)$。

 

参考知乎的回答

首先是一个非常重要的转化:一定存在某个常数$\lambda(\lambda \geq 0)$,使得$a_i + \lambda b_i$前K小的点集与原问题的最优解相同。(对应到最小乘积生成树问题中,即将边权按$a_i + \lambda b_i$排序做kruskal算法得到的最小生成树就是最小乘积生成树, 以下内容将考虑取$K$个点的问题,请自行对应到最小乘积生成树问题)。

证明: 假设最优解的$\sum{a_i} = A, \sum{b_i} = B$, 取$\lambda = \frac{B}{A}$即可。

取$\lambda = \frac{B}{A}$. 则$a_i + \lambda b_i$前K小的点集的$\sum{a_i} = A', \sum{b_i} = B'$.

则$A' + \lambda B' \leq  A + \lambda B$

若$A' + \lambda B' \lt A + \lambda B$

则$A' \cdot \lambda B' \leq (\frac{A' + \lambda B'}{2})^2 \lt (\frac{A+\lambda B}{2})^2 = \lambda AB $    即 $ A'B' < AB$,于是产生了矛盾。

因此一定有$A' + \lambda B' =  A + \lambda B$.  所以最优解一定对应$\sum{a_i + \lambda b_i}$最小的某个解。

而我们的目的是$\sum{a_i} \cdot \sum{b_i}$最小,因此$\sum{a_i + \lambda b_i}$最小还不够,还要使得$\sum{a_i}$最小或者$\sum{b_i}$最小.

因此,我们只要枚举所有的$\lambda$,然后按$a_i + \lambda b_i$为第一关键字,$a_i$为第二关键字,排序做一遍,再按$a_i + \lambda b_i$为第一关键字,$b_i$为第二关键字做一遍,一定可以得到最优解。

虽然$\lambda$的取值有无穷多种,但实际上我们只要考虑$a_i + \lambda b_i$的排序结果不同的$\lambda$。 即使得$a_i + \lambda b_i = a_j + \lambda b_j$的这些$\lambda$。 再进一步分析,其实我们甚至不需要对每个$\lambda$分别以$a_i$和$b_i$为第二关键字排序做两遍。假设使得$a_i + \lambda b_i = a_j + \lambda b_j$的这些$\lambda$分别是$\lambda_1, \lambda_2, \cdots, \lambda_k$, 我们只需要分别在区间$(0, \lambda_1), (\lambda_1, \lambda_2) \cdots (\lambda_{k-1}, \lambda_k), (\lambda_k, +inf)$各取一个$\lambda$,仅仅按照$a_i + \lambda b_i$排序做一遍即可。

因此对于本题总共只要做$O(N^2)$遍排序,总的复杂度$O(N^3logN)$;对于最小乘积生成树问题,只要做$O(M^2)$遍普通最小生成树即可。

 

对于本题,只要枚举$\lambda$然后排序取前$K$小。如果按从小到大的顺序枚举$\lambda$,每次改变的时候我们可以知道哪些元素的排名上升了,哪些下降了,因此可以直接维护,不必每次重新排序。具体只要每次将那些排名会发生变化的元素重新拿出来排序,然后插回去。可以想象成,数轴上有$N$辆车,起始位置分别是$a_i$,速度分别是$b_i$,而$\lambda$可以看做时间,随着$\lambda$的变大,维护车的排名。  复杂度可以降低到$O(N^2logN)$.

然而,最小乘积生成树问题并不能这样维护,因为kruskal算法不是单纯的取最小的几条边。

 

本题代码:

 

  1 #include <bits/stdc++.h>
  2 using namespace std;
  3 
  4 typedef long long LL;
  5 #define N 1010
  6 int a[N], b[N], rk[N], now[N], n, K;
  7 double val[N];
  8 LL ans = 1e18;
  9 vector<int> best;
 10 const double EPS = 1e-6;
 11 
 12 
 13 struct event
 14 {
 15     double t;
 16     int k1, k2;
 17     bool operator < (const event &o)const
 18     {
 19         return t < o.t;
 20     }
 21 };
 22 vector<event> L;
 23 
 24 
 25 bool cmp(int x, int y) {return val[x] < val[y];}
 26 
 27 int main()
 28 {
 29     //freopen("in.txt", "r", stdin);
 30  
 31     scanf("%d %d", &n, &K);
 32     for (int i = 1; i <= n; ++i)
 33         scanf("%d %d", &a[i], &b[i]), now[i] = i, val[i] = a[i];
 34     sort(now + 1, now + n + 1, cmp);
 35     for (int i = 1; i <= n; ++i) rk[now[i]] = i;
 36     
 37     
 38     LL sa = 0, sb = 0;
 39     for (int i = 1; i <= K; ++i) 
 40     {
 41         sa += a[now[i]], sb += b[now[i]];
 42          best.push_back(now[i]);
 43     }
 44     ans = sa * sb;
 45     
 46     
 47     for (int i = 1; i < n; ++i)
 48     {
 49         for (int j = i + 1; j <= n; ++j)
 50         {
 51             if (b[i] == b[j]) continue;
 52             //a[i] + k*b[i] = a[j] + k * b[j]  
 53             if (1LL * (a[j] - a[i]) * (b[i] - b[j]) >= 0) 
 54                 L.push_back((event){(double)(a[j] - a[i]) / (b[i] - b[j]), i, j});
 55         }
 56     }
 57     sort(L.begin(), L.end()); 
 58     
 59     for (int i = 0, j; i < L.size(); i = j + 1)
 60     {
 61         vector<int> lis;
 62         vector<int> pos;
 63         
 64         j = i;
 65         while (j + 1 < L.size() && fabs(L[j + 1].t - L[i].t) < EPS) ++j;
 66         for (int k = i; k <= j; ++k)
 67         {
 68             lis.push_back(L[k].k1); 
 69             lis.push_back(L[k].k2);
 70         }
 71         sort(lis.begin(), lis.end());
 72         lis.erase(unique(lis.begin(), lis.end()), lis.end());
 73                                        
 74         double lamda = j + 1 == L.size()? L[i].t + 1: (L[j].t + L[j + 1].t) / 2.0;
 75         for (auto x: lis) val[x] = a[x] + lamda * b[x], pos.push_back(rk[x]);
 76         sort(lis.begin(), lis.end(), cmp);
 77         sort(pos.begin(), pos.end());
 78         for (int k = 0; k < lis.size(); ++k)
 79         {
 80             int x = lis[k];
 81             //cout <<  "!! " << x << " " << rk[x] << " " << val[x] << endl; 
 82             if (rk[x] <= K) sa -= a[x], sb -= b[x];
 83             if (pos[k] <= K) sa += a[x], sb += b[x];
 84             rk[x] = pos[k];
 85             now[pos[k]] = x;
 86         }
 87         assert(sa > 0 && sb > 0);
 88         if (sa * sb < ans)
 89         {
 90                ans = sa * sb;
 91                best.clear();
 92                for (int k = 1; k <= K; ++k)
 93                    best.push_back(now[k]);
 94         }
 95     }
 96 
 97     sa = sb = 0;
 98     printf("%lld\n", ans);
 99     sort(best.begin(), best.end());
100     for (int i = 0; i < best.size(); ++i)
101         printf("%d%c", best[i], i + 1 == best.size()? '\n': ' '), sa += a[best[i]], sb += b[best[i]];
102     assert(sa * sb == ans);
103     return 0;
104 }

 

posted @ 2018-11-14 13:04  lzw4896s  阅读(1066)  评论(1编辑  收藏  举报