Codechef Tree Balancing (TREEBAL)
题目链接
Codechef Tree Balancing (TREEBAL)
题目大意
给定一棵大小为 \(n\) 的树,根节点为节点 \(1\), 第 \(i\) 条边的长度为 \(w_i\)。
你可以进行若干次修改,每次选择一条边 \(i\),给 \(w_i\) 加上 \(1\) 或减去 \(1\),代价为 \(c_i\),注意 \(w_i\) 可以修改成负数。
你的目标是让所有叶子到根节点的距离相同,求达到目标的最小代价,以及任意一组最优方案。
\(1\leq n\leq 2\times 10^5,\;1\leq w_i,c_i\leq 10^6\)
思路
朴素的树 \(dp\) 想法显然是设 \(dp_{u,i}\) 表示子树 \(u\) 内的所有叶子深度(在算上整棵子树和 \(u\) 的父边时)为 \(i\) 时的最小代价,转移为 \(dp_{u,i}=\min_{l}\{|w_e - l|+\sum dp_{v,i-l}\}\)。
注意到绝对值函数是凸函数,而 \(dp\) 转移的过程是一系列凸函数的叠加,所以 \(dp_{u,i}\) 是一个关于 \(i\) 的分段线性凸函数,且每次只会增加 \(O(1)\) 个段,于是考虑转而维护凸包。
维护凸包上的所有拐点的横坐标,每个拐点对应两侧斜率变化了 \(1\),从而一些坐标会出现很多次,于是用 set 维护二元组 \((x,num)\) 表示坐标和出现次数,另外需要知道凸包的斜率和具体数值,维护凸包最右侧的射线的表达式 \(y=ax+b\),即记录 \((a,b)\) 。
叶子:\(dp_u:(w_e,2c_e)\),\(y=c_ex-c_ew_e\)
合并儿子信息:启发式合并所有 set,将 \((a,b)\) 对应相加。
增加 \(u\) 的父边 \(e\):\(e\) 的意义在于当凸包的斜率绝对值 \(>c_e\) 时,这些地方的深度变化用修改 \(e\) 的长度会更优,于是 删去凸包两侧斜率在 \([-c_e,c_e]\) 之外的拐点,在删除右侧拐点 \(x\) (一个)时,\((a,b)\rightarrow (a-1,b+x)\)。同时所 有拐点向右平移 \(w_e\),此时 \((a,b)\rightarrow (a,b-aw_e)\) 。
在求答案时,一个个删去斜率 \(>0\) 的拐点,最终得到最底部斜率为 \(0\) 的线段端点,以及表达式 \(y=0x+b\),于是答案为 \(b\),具体深度任取线段上一值即可。
对于求方案,注意到一条边 \(e\) 只在原斜率绝对值 \(>c_e\) 时才发挥作用,所以我们维护好每个点的凸包的最左和最右两个拐点 \(l_i,r_i\),以及两侧射线的斜率 \(mn_i,mx_i\)。设我们当前求的是子树 \(u\) 以及父边 \(e\) 内,叶子深度 \(=x\) 的方案,若 $ x>r_i$ 且 \(mx_i=c_e\),则 \(w_e\) 要加上 \(x-r_i\),若 \(x<l_i\) 且 \(mn_i=-c_e\),则 \(w_e\) 要减去 \(l_i-x\),否则 \(e\) 对最优方案没啥贡献,不变。然后子树内 \(x\rightarrow x-w_e\) 递归求解即可。
时间复杂度 \(O(n\log^2n)\) 。
Code
#include<iostream>
#include<cstring>
#include<set>
#define mem(a,b) memset(a, b, sizeof(a))
#define rep(i,a,b) for(int i = (a); i <= (b); i++)
#define per(i,b,a) for(int i = (b); i >= (a); i--)
#define N 200021
#define ll long long
#define PLL pair<ll, ll>
#define fr first
#define sc second
using namespace std;
int n;
int head[N], nxt[2*N], to[2*N];
ll w[2*N], c[2*N], len[2*N];
int cnt;
struct Hull{
set<PLL> s;
ll a, b, shift, num;
void adjust(ll c, bool id){
if(id) while(a-num < -c){
ll need = -c-a + num;
PLL p = *s.begin();
num -= min(p.sc, need), s.erase(s.begin());
if(p.sc > need) s.insert({p.fr, p.sc-need});
}
while(a > c){
PLL p = *(--s.end());
ll del = min(a-c, p.sc);
a -= del, b += del * (p.fr+shift), num -= del, s.erase(--s.end());
if(p.sc > del) s.insert({p.fr, p.sc-del});
}
}
} hull[N];
ll lb[N], rb[N], mn[N], mx[N];
void init(){
rep(i,1,n) head[i] = -1; cnt = -1;
rep(i,1,n) hull[i].s.clear(), hull[i].a = hull[i].b = hull[i].shift = hull[i].num = 0;
}
void add_e(int a, int b, int w, int c, bool id){
nxt[++cnt] = head[a], head[a] = cnt, to[cnt] = b;
::w[cnt] = w, ::c[cnt] = c;
if(id) add_e(b, a, w, c, 0);
}
void merge(Hull &a, Hull &b){
if(a.s.size() < b.s.size()) swap(a, b);
for(PLL p : b.s){
ll x = p.fr + b.shift - a.shift;
auto it = a.s.lower_bound({x, 0});
if(it != a.s.end() && it->fr == x) p.sc += it->sc, a.s.erase(it);
a.s.insert({x, p.sc});
}
a.num += b.num;
a.a += b.a, a.b += b.b;
}
void dfs(int x, int e){
bool leaf = true;
for(int i = head[x]; ~i; i = nxt[i]) if(i != (e^1)){
leaf = false;
dfs(to[i], i);
merge(hull[x], hull[to[i]]);
}
if(leaf){
hull[x].s.insert({w[e], 2*c[e]});
hull[x].a = c[e], hull[x].b = -c[e] * w[e];
hull[x].num = 2*c[e];
} else{
hull[x].adjust(c[e], 1), hull[x].shift += w[e];
hull[x].b -= hull[x].a * w[e];
}
lb[x] = hull[x].s.begin()->fr, rb[x] = (--hull[x].s.end())->fr;
lb[x] += hull[x].shift, rb[x] += hull[x].shift;
mn[x] = hull[x].a - hull[x].num, mx[x] = hull[x].a;
}
void solution(int x, int e, ll L){
len[e] = w[e];
if(L > rb[x] && mx[x] == c[e]) len[e] += L-rb[x];
else if(L < lb[x] && mn[x] == -c[e]) len[e] -= lb[x]-L;
len[e^1] = len[e];
for(int i = head[x]; ~i; i = nxt[i]) if(i != (e^1))
solution(to[i], i, L-len[e]);
}
int main(){
ios::sync_with_stdio(false);
int T; cin>>T;
while(T--){
cin>>n;
int u, v, _w, _c; init();
rep(i,1,n-1) cin>>u>>v>>_w>>_c, add_e(u, v, _w, _c, 1);
for(int i = head[1]; ~i; i = nxt[i])
dfs(to[i], i), merge(hull[1], hull[to[i]]);
hull[1].adjust(0, 0);
ll L = hull[1].s.empty() ? 0 : (--hull[1].s.end())->fr + hull[1].shift;
cout<< hull[1].a * L + hull[1].b <<endl;
for(int i = head[1]; ~i; i = nxt[i])
solution(to[i], i, L);
rep(i,0,cnt/2) cout<< len[i*2] <<endl;
}
return 0;
}