【树染色】 Color a tree

传送门

题意

一个树有\(n\)个节点,每个节点\(i\)有一个权值\(A_{i}\),把所有的节点染色,染色的规则是根节点\(R\)可以随时被染色,
其他节点必须父节点被染色才能被染色,每次染色的代价为\(T · A_{i}\) ,\(T\)记录当前是第几个被染色的,求给所有节点染色的最小代价

数据范围

\(\begin{array}{l}1 \leq n \leq 1000 \\ 1 \leq A_{i} \leq 1000\end{array}\)

题解

错误的贪心:在每一个子树的子节点中选择值最大的,
构造一个极端:让一个权值很小的子树根下面子节点权值极大
一个正确的性质:树中除了根外权最大的点,一定会在其父节点被染色后立即染色
即树中节点权值最大的点和其父节点的染色是绑定的,将这两个合并,合并后的新点的权值是二者平均值
即给\(x,y,z\)三点染色,其中\(x,y\)是绑定的,\(x\)是父节点,有两种可能

  • 先染\(x\)\(y\)\(z\),代价 \(x+2y+3z\)
  • \(z\)\(x\)\(y\),代价\(z+2x+3y\)

\(x+2y+3z < z+2x+3y\) 化简后得到\(z < \frac{x+y}{2}\),即当均值大的时候后染色一定代价更大
当有两组点\(a_{1}, a_{2}, \dots a_{n}\)\(b_{1}, b_{2}, \ldots b_{m}\)的情况时,组内的点在染色时是相邻的一段。

  • 先染\(a\),\(S_{a b}=\sum_{i=1}^{n} a_{i} * i+\sum_{i=n+1}^{n+m} b_{i} * i\)
  • 先染\(b\),\(S_{b a}=\sum_{i=1}^{m} b_{i} * i+\sum_{i=m+1}^{n+m} a_{i} * i\)

可以得到\(S_{a b}-S_{b a}=n * \sum_{i=1}^{m} b_{i}-m * \sum_{i=1}^{n} a_{i}\),所以\(S_{a b}-S_{b a}<0 \Longleftrightarrow \frac{\sum_{i=1}^{n} a_{i}}{n}<\frac{\sum_{i=1}^{m} b_{i}}{m}\)
即每次找出权值最大的非根节点,将其染色顺序排在紧随父节点之后的位置,然后将该点合并进父节点中,更新父节点的权值。直到将所有点都合并进根节点为止。
总代价的计算方式:

  • 最初所有点各自为一组,总分值是 \(S=\sum_{i=1}^{n} a_{i}\)
  • 接下来每次会将两组点合并,将其中一组点接在另一组点的后面。比如两组点分别是 \(x_{i}\)\(y_{i}\),我们将 \(y_{i}\) 接在 \(x_{i}\) 之后,则 \(y_{i}\) 中每个点所乘的系数均会增加一个相同的偏移量,这个偏移量就是 \(x_{i}\) 中点的个数,假设是 \(k\),则合并之后,总的权值直接加上 \(k * \sum y_{i}\) 即可;

Code

#include<bits/stdc++.h>
using namespace std;
#define rep(i,a,n) for(int i=a;i<n;i++)
#define per(i,a,n) for(int i=n-1;i>=a;i--)
#define fi first 
#define se second 
#define ll long long
#define pb push_back
typedef pair<int,int> pii;
const int N=1010;
int n,r;
struct node{
    int v,sz,fa;
    double avg;
}e[N];
int find_max(){ 
   double mx=-1;
   int res;
   for(int i = 1; i <= n; i++){
      if(i != r && mx < e[i].avg){
         mx=e[i].avg;
         res=i;
      }
   }
   return res;
}
int main(){
    scanf("%d%d",&n,&r);
    int ans=0;
    rep(i,1,n+1) {
        scanf("%d",&e[i].v);
        e[i].sz=1;
        e[i].avg=e[i].v;
        ans+=e[i].v;
    }
    rep(i,0,n-1){
        int x,y;
        scanf("%d%d",&x,&y);
        e[y].fa=x;
    }
    rep(i,0,n-1){
        int p=find_max();
        int p_fa=e[p].fa;
        ans+=e[p].v*e[p_fa].sz;
        e[p].avg=-1;

        rep(j,1,n+1) if(e[j].fa==p) e[j].fa=p_fa;
        e[p_fa].v += e[p].v;
        e[p_fa].sz += e[p].sz;
        e[p_fa].avg=(double) e[p_fa].v/e[p_fa].sz;
    }
    printf("%d\n",ans);
}

posted @ 2020-06-29 21:22  Hyx'  阅读(516)  评论(1编辑  收藏  举报