给树染色
给出一棵有n个节点的树,第i个点的权值为\(a_i\),根节点为r,现在你需要给所有点染色,一个点能够被染色,当且仅当它的父亲节点已经被染色(当然,根节点除外,可以直接被染色),而且对于一个点i染色的费用为\(\text{第几次染色}\times a_i\),询问最少的染色费用,\(n\leq 1000\)。
解
注意到一个性质,就是当树中点权最大的点q的父亲被染色后,q会马上被染色,否则替换成另外一个点w,那么有\(a_w<a_q\),设这是第\(e\)次染色,显然这两次染色的交换不会影响到e次以前的和\(e+2\)次染色以后的染色,于是知道q先染色的对答案的影响为\(e\times a_q+(e+1)\times a_w\),w先染色的对答案的影响为\(e\times a_w+a_q(e+1)\),后者与前者做差,有\(a_q-a_w>0\),于是q先染色会更加优秀。
再琢磨一下这个性质,q的父亲被染色了,那么q就会立刻被染色,这类似于两个点被强制捆绑在了一起,不妨进一步研究捆绑的性质,猜测捆绑以后的点权,设有点\(x,y,z\),设y,z已经被捆绑在了一起,那么对答案的影响只有两种可能(因为前面的染色次数,对大小关系的比较没有影响,第一段已经间接说明了这一点,一下全部省去),
同样想着构造y,z同样的形式构成一个整体(因为最后点权得是一个固定的数字)和\(x+2x+3x+...\)(因为这样才能给出显然的证明)。
因为要构造y,z一个整体,想到两边同时减z,有
要构造\(x+2x+3x+...\),就想到同时加上一个\(a_x\),再除以2以后有
这显然告诉我们y,z可以合并成一个点权为\(\frac{a_y+a_z}{2}\)的点,进一步推广,假设有u个点v被捆绑在一起,不妨假设这些点为\(1,2,...,u\),其中数字的顺序就是染色的顺序,再与点x考虑,只有两种情况
同样加上\(-a_2-2a_3...-(u-1)a_u+(u-1)x\),再除以u,有
显然这就代表了\(1,2,,,..u\)这些点构成了一个整体,点权为\(\frac{\sum_{i=1}^ua_i}{u}\)。
于是总上,我们有做法,每次在树中找到点权最大的点,然后将它与它的父亲合并,更改一下点权,并趁机记录合并后点中的染色的顺序,最后树只会剩下一个点,那么这个点的染色顺序也就是所求,依次为依据手动模拟一下就可算出答案,时间复杂度\(O(n^2)\)。
后话:因为每次寻找到一个点权最大的点,这好像可以用优先队列,但是需要和父亲合并,也就是要支持动态修改,于是萎了,然后考虑用平衡树,需要找到父亲很麻烦,而且点权都相同的话,查找可以到\(O(n)\),故不想写了,如果各位像sxr一样强的人有办法,请联系我。
参考代码:
#include <iostream>
#include <cstdio>
#define il inline
#define ri register
#define Size 1500
#define swap(x,y) x^=y^=x^=y
using namespace std;
struct point{
point *next;int to;
}*pt,*head[Size];
int pa[Size],a[Size],ans[Size],
tot;
void dfs(int);
il void read(int&),link(int,int);
int main(){
int n,r,gzy(0);read(n),read(r);
for(int i(1);i<=n;++i)read(a[i]);
for(int i(1),u,v;i<n;++i)
read(u),read(v),link(u,v),pa[v]=u;dfs(r);
for(int i(1),j;i<=n;++i)
for(j=1;j<n;++j)
if(a[ans[j]]<a[ans[j+1]]&&pa[ans[j+1]]!=ans[j])
swap(ans[j],ans[j+1]);
for(int i(1);i<=n;++i)
gzy+=a[ans[i]]*i;printf("%d",gzy);
return 0;
}
void dfs(int x){ans[++tot]=x;
for(point *i(head[x]);
i!=NULL;i=i->next)dfs(i->to);
}
il void link(int u,int v){
pt=new point,pt->to=v;;
pt->next=head[u],head[u]=pt;
}
il void read(int &x){
x^=x;ri char c;while(c=getchar(),c<'0'||c>'9');
while(c>='0'&&c<='9')x=(x<<1)+(x<<3)+(c^48),c=getchar();
}