【codeforces 348B】Apple Tree
【题目链接】:http://codeforces.com/problemset/problem/348/B
【题意】
给你一棵树;
叶子节点有权值;
对于非叶子节点;
它的权值是以这个节点为根的子树上的叶子节点的权值的和;
定义一棵树是平衡的,当且仅当,每个节点的所有直系儿子的权值都相等;
问你要使得这棵树平衡,最少需要删除掉多少叶子节点上的权值;
【题解】
在第一个dfs里面求出d[i]和s[i];
设d[i]表示以i为根的子树要平衡的话最少需要多少权值(注意这里不是说最少要删去多少权值,而是最少需要多少权值),对于叶子节点d[i]=1,非叶子节点的话,d[i] = k*lcm(d[j]),j是i的儿子节点,k是i的直系儿子节点个数;
因为只有按照这样的分法,才能保证都能平均地分下来;也就是说如果最后i节点有权值的话,他一定得是d[i]的倍数;
(lcm是最小公倍数);
s[i]是以i为根的子树里面叶子节点的权值和;
在第二个dfs里面算需要删掉多少;
对于i节点;
先求出t = d[i]/k;这里k同样是i奇点的直系儿子节点的个数;
这样,我们就先算出了,每个儿子节点的权值都应该是t的倍数才对;
然后我们求出所有儿子节点j里面,s的值最小的s[j]=mis;
然后让所有的s[j]变成最小的t的倍数,也即t = (mis/t) * t
然后∑s[j] - t*k 就是需要扣除掉的;
注意这里之后要更新每个儿子节点的s值,同时更新i节点的s值;
注意为根节点的时候,len和非和节点的时候的len…,根节点就算len=1也得继续往下做,不能直接退出。。。所以注意叶子节点的判断。
【Number Of WA】
1
【反思】
这道题,注意到每个节点都有一个最小的平衡权值,且注意到都得是这个最小平衡权值的整数倍是关键。
【完整代码】
#include <bits/stdc++.h>
using namespace std;
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define LL long long
#define rep1(i,a,b) for (int i = a;i <= b;i++)
#define rep2(i,a,b) for (int i = a;i >= b;i--)
#define mp make_pair
#define pb push_back
#define fi first
#define se second
#define ms(x,y) memset(x,y,sizeof x)
#define Open() freopen("F:\\rush.txt","r",stdin)
#define Close() ios::sync_with_stdio(0)
typedef pair<int,int> pii;
typedef pair<LL,LL> pll;
const int dx[9] = {0,1,-1,0,0,-1,-1,1,1};
const int dy[9] = {0,0,0,-1,1,-1,1,-1,1};
const double pi = acos(-1.0);
const int N = 1e5+100;
const LL oo = 1e18;
LL d[N],s[N],a[N],sum,ans;
int n;
vector <int> G[N];
LL lcm(LL x,LL y){
return (x/__gcd(x,y))*y;
}
void out(){
cout << sum << endl;
exit(0);
}
void dfs1(int x,int fa){
d[x] = 1,s[x] = a[x];
if (G[x].size()==1 && x!=1){
return;
}
int len = G[x].size();
d[x] = 1;
rep1(i,0,len-1){
int y = G[x][i];
if (y==fa) continue;
dfs1(y,x);
d[x] = lcm(d[x],d[y]);
if (d[x]>sum) out();
s[x] += s[y];
}
if (x==1)
d[x] = d[x]*len;
else
d[x] = d[x]*(len-1);
}
void dfs2(int x,int fa){
int len = G[x].size();
if (len==1 && x!=1) return;
LL mis = oo,temp = 0;
rep1(i,0,len-1){
int y = G[x][i];
if (y == fa) continue;
dfs2(y,x);
mis = min(mis,s[y]);
temp+=s[y];
}
LL t;
if (x==1)
t = d[x]/len;
else
t = d[x]/(len-1);
t = (mis/t)*t;
if (x==1)
ans += temp-t*len;
else
ans += temp-t*(len-1);
s[x] = a[x];
rep1(i,0,len-1){
int y = G[x][i];
if (y==fa) continue;
s[y] = t;
s[x]+=s[y];
}
}
int main(){
//Open();
Close();
cin >> n;
rep1(i,1,n){
cin >> a[i];
sum += a[i];
}
rep1(i,1,n-1){
int x,y;
cin >> x >> y;
G[x].pb(y);
G[y].pb(x);
}
dfs1(1,-1);
dfs2(1,-1);
cout << ans << endl;
return 0;
}