Codechef Tree Balancing (TREEBAL)

题目链接

Codechef Tree Balancing (TREEBAL)

题目大意

给定一棵大小为 \(n\) 的树,根节点为节点 \(1\), 第 \(i\) 条边的长度为 \(w_i\)

你可以进行若干次修改,每次选择一条边 \(i\),给 \(w_i\) 加上 \(1\) 或减去 \(1\),代价为 \(c_i\),注意 \(w_i\) 可以修改成负数。

你的目标是让所有叶子到根节点的距离相同,求达到目标的最小代价,以及任意一组最优方案。

\(1\leq n\leq 2\times 10^5,\;1\leq w_i,c_i\leq 10^6\)

思路

朴素的树 \(dp\) 想法显然是设 \(dp_{u,i}\) 表示子树 \(u\) 内的所有叶子深度(在算上整棵子树和 \(u\) 的父边时)为 \(i\) 时的最小代价,转移为 \(dp_{u,i}=\min_{l}\{|w_e - l|+\sum dp_{v,i-l}\}\)

注意到绝对值函数是凸函数,而 \(dp\) 转移的过程是一系列凸函数的叠加,所以 \(dp_{u,i}\) 是一个关于 \(i\) 的分段线性凸函数,且每次只会增加 \(O(1)\) 个段,于是考虑转而维护凸包。

维护凸包上的所有拐点的横坐标,每个拐点对应两侧斜率变化了 \(1\),从而一些坐标会出现很多次,于是用 set 维护二元组 \((x,num)\) 表示坐标和出现次数,另外需要知道凸包的斜率和具体数值,维护凸包最右侧的射线的表达式 \(y=ax+b\),即记录 \((a,b)\)

叶子:\(dp_u:(w_e,2c_e)\)\(y=c_ex-c_ew_e\)

合并儿子信息:启发式合并所有 set,将 \((a,b)\) 对应相加。

增加 \(u\) 的父边 \(e\)\(e\) 的意义在于当凸包的斜率绝对值 \(>c_e\) 时,这些地方的深度变化用修改 \(e\) 的长度会更优,于是 删去凸包两侧斜率在 \([-c_e,c_e]\) 之外的拐点,在删除右侧拐点 \(x\) (一个)时,\((a,b)\rightarrow (a-1,b+x)\)。同时所 有拐点向右平移 \(w_e\),此时 \((a,b)\rightarrow (a,b-aw_e)\)

在求答案时,一个个删去斜率 \(>0\) 的拐点,最终得到最底部斜率为 \(0\) 的线段端点,以及表达式 \(y=0x+b\),于是答案为 \(b\),具体深度任取线段上一值即可。

对于求方案,注意到一条边 \(e\) 只在原斜率绝对值 \(>c_e\) 时才发挥作用,所以我们维护好每个点的凸包的最左和最右两个拐点 \(l_i,r_i\),以及两侧射线的斜率 \(mn_i,mx_i\)。设我们当前求的是子树 \(u\) 以及父边 \(e\) 内,叶子深度 \(=x\) 的方案,若 $ x>r_i$ 且 \(mx_i=c_e\),则 \(w_e\) 要加上 \(x-r_i\),若 \(x<l_i\)\(mn_i=-c_e\),则 \(w_e\) 要减去 \(l_i-x\),否则 \(e\) 对最优方案没啥贡献,不变。然后子树内 \(x\rightarrow x-w_e\) 递归求解即可。

时间复杂度 \(O(n\log^2n)\)

Code

#include<iostream>
#include<cstring>
#include<set>
#define mem(a,b) memset(a, b, sizeof(a))
#define rep(i,a,b) for(int i = (a); i <= (b); i++)
#define per(i,b,a) for(int i = (b); i >= (a); i--)
#define N 200021
#define ll long long
#define PLL pair<ll, ll>
#define fr first
#define sc second
using namespace std;

int n;
int head[N], nxt[2*N], to[2*N];
ll w[2*N], c[2*N], len[2*N];
int cnt;

struct Hull{
    set<PLL> s;
    ll a, b, shift, num;

    void adjust(ll c, bool id){
        if(id) while(a-num < -c){
            ll need = -c-a + num;
            PLL p = *s.begin();
            num -= min(p.sc, need), s.erase(s.begin());
            if(p.sc > need) s.insert({p.fr, p.sc-need});
        }
        while(a > c){
            PLL p = *(--s.end());
            ll del = min(a-c, p.sc);
            a -= del, b += del * (p.fr+shift), num -= del, s.erase(--s.end());
            if(p.sc > del) s.insert({p.fr, p.sc-del});
        }
    }
} hull[N];
ll lb[N], rb[N], mn[N], mx[N];

void init(){ 
    rep(i,1,n) head[i] = -1; cnt = -1; 
    rep(i,1,n) hull[i].s.clear(), hull[i].a = hull[i].b = hull[i].shift = hull[i].num = 0;
}
void add_e(int a, int b, int w, int c, bool id){
    nxt[++cnt] = head[a], head[a] = cnt, to[cnt] = b;
    ::w[cnt] = w, ::c[cnt] = c;
    if(id) add_e(b, a, w, c, 0);
}

void merge(Hull &a, Hull &b){
    if(a.s.size() < b.s.size()) swap(a, b);
    for(PLL p : b.s){
        ll x = p.fr + b.shift - a.shift;
        auto it = a.s.lower_bound({x, 0});
        if(it != a.s.end() && it->fr == x) p.sc += it->sc, a.s.erase(it);
        a.s.insert({x, p.sc});
    }
    a.num += b.num;
    a.a += b.a, a.b += b.b;
}

void dfs(int x, int e){
    bool leaf = true;
    for(int i = head[x]; ~i; i = nxt[i]) if(i != (e^1)){
        leaf = false;
        dfs(to[i], i);
        merge(hull[x], hull[to[i]]);
    }
    if(leaf){
        hull[x].s.insert({w[e], 2*c[e]});
        hull[x].a = c[e], hull[x].b = -c[e] * w[e];
        hull[x].num = 2*c[e];
    } else{
        hull[x].adjust(c[e], 1), hull[x].shift += w[e];
        hull[x].b -= hull[x].a * w[e];
    }

    lb[x] = hull[x].s.begin()->fr, rb[x] = (--hull[x].s.end())->fr;
    lb[x] += hull[x].shift, rb[x] += hull[x].shift;
    mn[x] = hull[x].a - hull[x].num, mx[x] = hull[x].a;
}

void solution(int x, int e, ll L){
    len[e] = w[e];
    if(L > rb[x] && mx[x] == c[e]) len[e] += L-rb[x];
    else if(L < lb[x] && mn[x] == -c[e]) len[e] -= lb[x]-L;
    len[e^1] = len[e];

    for(int i = head[x]; ~i; i = nxt[i]) if(i != (e^1))
        solution(to[i], i, L-len[e]);
}

int main(){
    ios::sync_with_stdio(false);
    int T; cin>>T;
    while(T--){
        cin>>n;
        int u, v, _w, _c; init();
        rep(i,1,n-1) cin>>u>>v>>_w>>_c, add_e(u, v, _w, _c, 1);

        for(int i = head[1]; ~i; i = nxt[i])
            dfs(to[i], i), merge(hull[1], hull[to[i]]);

        hull[1].adjust(0, 0);
        ll L = hull[1].s.empty() ? 0 : (--hull[1].s.end())->fr + hull[1].shift;
        cout<< hull[1].a * L + hull[1].b <<endl;

        for(int i = head[1]; ~i; i = nxt[i])
            solution(to[i], i, L);
        rep(i,0,cnt/2) cout<< len[i*2] <<endl;
    }
    return 0;
}
posted @ 2021-12-21 08:54  Neal_lee  阅读(64)  评论(0编辑  收藏  举报