CF1656F Parametric MST 题解
为了便于解题,先对 \(a\) 数组从小到大进行排序。
首先,根据定义可以得出总价值的表达式:
接着,我们需要发现一个比较重要的性质:
- \(w_{i,j}(t)=a_ia_j+t(a_i+a_j)=(a_i+t)(a_j+t)-t^2\)
也就是说,如果固定一个 \(t\),那么 \(t^2\) 就是定值,可以暂不考虑;\(\forall 1\le i\le n\),如果 \(a_i+t\) 是正数,就向结点 \(1\) 连边以最小化该点的贡献;如果 \(a_i+t\) 是负数,就向结点 \(n\) 连边(如果 \(a_i+t=0\) 那么向哪个点连边都一样)。最后去掉 \(1\) 与 \(n\) 之间的重边以及若干自环,即可构造出正好有 \(n-1\) 条边的最小生成树。
现在来考虑一下无解的情况:
-
如果 \(\forall 1\le i\le n,a_i+t>0\),那么除了 \(1\) 以外的所有结点向 \(1\) 连边,有
\[\begin{aligned} W&=\sum\limits_{(u,v)\in E}a_ua_v+t\sum\limits_{(u,v)\in E}(a_u+a_v)\\ &=a_1\sum\limits_{i=2}^na_i+t\left[(n-1)a_1+\sum\limits_{i=2}^na_i\right]\\ \end{aligned} \]可以发现,\(W\) 是关于 \(t\) 的一次函数。由于满足 \(\forall 1\le i\le n,a_i+t>0\),所以如果一次项系数 \((n-1)a_1+\sum\limits_{i=2}^na_i>0\),那么当 \(t\rightarrow+\infty\) 时 \(W\rightarrow+\infty\),该函数不存在最大值。
-
如果 \(\forall 1\le i\le n,a_i+t<0\),同理有:
\[\begin{aligned} W&=a_n\sum\limits_{i=1}^{n-1}a_i+t\left[(n-1)a_n+\sum\limits_{i=1}^{n-1}a_i\right]\\ \end{aligned} \]类似的,如果 \((n-1)a_n+\sum\limits_{i=1}^{n-1}a_i<0\),当 \(t\rightarrow-\infty\) 时,\(W\rightarrow+\infty\),不存在最大值。
综上,如果 \((n-1)a_1+\sum\limits_{i=2}^na_i>0\) 或 \((n-1)a_n+\sum\limits_{i=1}^{n-1}a_i<0\),边权和不存在最大值,直接输出 INF
即可。
现在来考虑有解时如何寻找解。这时我们就需要引入一个新的性质:
-
如果 \(W\) 能取到最大值,\(-t\) 的值一定是 \(a\) 数组中的其中一个元素的值。
证明:
容易得知当 \(W\) 取到最值时,\(a_1\le -t\le a_n\)(否则函数不收敛)。
当 \(-t\in[a_i,a_{i+1}](1\le i<n)\) 时,最优的连边方式之一是将所有的 \(1\le j\le i\) 向结点 \(n\) 连边(因为这些 \(j\) 满足 \(a_j+t\le 0\)),其他结点向 \(1\) 连边。所以我们可以得到 \(W\) 的表达式:
\[\begin{aligned} W&=a_n\sum\limits_{j=2}^ia_j+t\left[(i-1)a_n+\sum\limits_{j=2}^ia_j\right]+a_1\sum\limits_{j=i+1}^na_j+t\left[(n-i)a_1+\sum_{j=i+1}^na_j\right] \end{aligned} \](注意到这里 \(1\) 和 \(n\) 之间的边仅仅被连了一次,所以最终不用考虑重边的影响。)
在这里因为 \(i\) 确定,所以 \(W\) 还是关于 \(t\) 的一次函数,最值必然在 \(-t=a_i\) 或 \(-t=a_{i+1}\) 时取到。命题得证。
接下来只用枚举 \(i=1,2,\ldots,n\),然后令 \(t=-a_i\),将 \(t\) 带入证明中的那个表达式算出 \(W\) 的值。在所有的 \(W\) 中找到 \(W_{\max}\) 并输出即可。
中间那一些求和的式子可以使用前缀和维护。
放代码:
#include<bits/stdc++.h>
#define int long long
using namespace std;
main(){
ios::sync_with_stdio(false);
int t; cin>>t;
while(t--){
int n,c=LLONG_MIN; cin>>n;
vector<int> a(n),s;
for(auto &i:a)cin>>i;
sort(a.begin(),a.end()); // 排序
partial_sum(a.begin(),a.end(),back_inserter(s)); // 做前缀和
if(a[0]*(n-2)+s[n-1]>0||a[n-1]*(n-2)+s[n-1]<0)cout<<"INF\n"; // 无解情况
else{
for(int i=0;i<n;i++)
c=max(c,a[0]*(s[n-1]-s[i])-a[i]*(a[0]*(n-i-1)+s[n-1]-s[i])+a[n-1]*(s[i]-s[0])-a[i]*(a[n-1]*i+s[i]-s[0]));
// 带入表达式计算
cout<<c<<endl;
}
}
return 0;
}