HDU 1055 Color a Tree
题目:Color a Tree
链接:http://acm.hdu.edu.cn/showproblem.php?pid=1055
题意:给一棵树,要给树上每一个结点染色,第i 个结点染色需要代价为:t * w[i] (t 表示i 结点是第几个染色的),还有一个前提是:要给i 结点染色,必须先给i 结点的父结点染色(根结点随便染)。
思路:
贪心。
假定当前未染色的结点中权值最大的是A结点,如果A结点上方都染完色,那么现在一定是对A结点进行染色。
理解如下:
假设在上述情况下先对其他结点染色,在t 时刻后才轮到A,那么到t 时刻为止,消耗代价为:t1 * w1 + t2 * w2 + ... + tn * wn + t * wA,因为A在初始就具备了染色条件,并且 t = tn + 1,wA > wn,所以A可以往前挪一位,即t1 * w1 + t2 * w2 + ... + tn * wA + t * wn 必然更优,这点很好理解(wa 可以分成 wn + x),而继续上面的操作,最终wa 可以挪到最前面,所以上述情况下,一定是对A结点进行染色。
现在问题就是当A结点不是根结点的时候(即A的父结点还没有染色),可以想象,如果等会A的父结点染完色,那么接下来一定是轮到A了(A的父结点染色后,A就是第一种情况了),既然他们的染色顺序一定相邻,所以两个结点可以结合在一起。把他们看作一个新的结点,而新结点的权值可以用他们的平均值代替,假设等会在t时刻轮到他们染色,那么消耗代价本来为t * wfa + (t + 1) * wa,所以ans 可以先加上wa,那么等会遇到这个结点时,就可以ans += t * (w[fa] + w[a]),然后t+=2。
因此,结点保存3个值:sum(w总和),num(合并的数量),w(找最大标志)。对于为什么可以用平均值作为新结点的权值也很好理解了,因为他后面的消耗代价为t * sum 即 t * wfa + t*wa(注意:加完这个,t要+2,而不是普通的+1),他们之间是不能中断的,那么他们两个放到什么位置就取决于他们的平均值了。 (即tn*wn + (tn+1)*w和 tn * w + (tn+2)*wn 的对比了)
AC代码:
1 #include<stdio.h> 2 #include<vector> 3 using namespace std; 4 #define N 1010 5 #define LL long long 6 struct Node 7 { 8 int num,sum; 9 double w; 10 Node(){} 11 bool operator < (Node b) const 12 { 13 return w < b.w; 14 } 15 }v[N]; 16 17 int n; 18 int fa[N]; 19 int f[N]; 20 21 int find(int x) 22 { 23 int p=x; 24 while(p!=fa[p]) p=fa[p]; 25 int q; 26 while(fa[x]!=p) 27 { 28 q=fa[x]; 29 fa[x]=p; 30 x=q; 31 } 32 return p; 33 } 34 35 bool vis[N]; 36 int find() 37 { 38 int p = -1; 39 double maxt = 0; 40 for(int i=1;i<=n;i++) 41 { 42 if(vis[i]==1) continue; 43 if(v[i].w > maxt) 44 { 45 maxt = v[i].w; 46 p = i; 47 } 48 } 49 return p; 50 } 51 52 int main() 53 { 54 int r,x,y; 55 while(~scanf("%d%d",&n,&r)) 56 { 57 if(n==0 && r==0) break; 58 LL ans = 0; 59 for(int i=1;i<=n;i++) 60 { 61 fa[i] = i; 62 vis[i] = 0; 63 scanf("%d",&v[i].sum); 64 v[i].num = 1; 65 v[i].w = 1.0 * v[i].sum / v[i].num; 66 } 67 vis[0] = 1; 68 fa[0] = 0; 69 f[r]=0; 70 for(int i=1;i<n;i++) 71 { 72 scanf("%d%d",&x,&y); 73 f[y]=x; 74 } 75 int t = 1; 76 while(1) 77 { 78 int p = find(); 79 80 if(p==-1) break; 81 if(vis[find( f[p] ) ] == 1) 82 { 83 ans += t * v[p].sum; 84 t += v[p].num; 85 vis[p] = 1; 86 } 87 else 88 { 89 vis[p]=1; 90 int fp = find(f[p]); 91 ans += v[p].sum * v[fp].num; 92 v[fp].num += v[p].num; 93 v[fp].sum += v[p].sum; 94 v[fp].w = 1.0 * v[fp].sum / v[fp].num; 95 fa[p] = fp; 96 } 97 } 98 printf("%I64d\n",ans); 99 } 100 return 0; 101 } 102 103 /* 104 6 1 105 1 2 1 2 4 1 106 1 2 107 1 3 108 2 4 109 3 5 110 4 6 111 112 39 113 */