CF932F Escape Through Leaf 斜率优化、启发式合并

传送门


\(DP\)

\(f_i\)表示第\(i\)个节点的答案,\(S_i\)表示\(i\)的子节点集合,那么转移方程为\(f_i = \min\limits_{j \in S_i} \{a_i \times b_j + f_j\}\)

这是一个很明显的斜率优化式子,斜率为\(b_j\),截距为\(f_j\),自变量为\(a_i\)。考虑到斜率没有单调性,所以使用set维护凸包。

使用set维护凸包比较简单。一条直线插入时,先判断这条线段在当前凸包中是否合法,然后不断把两边不合法的直线删去。具体的实现看下面代码的insert函数吧

然后如何将一个点所有儿子的set合并为自己的set呢?使用启发式合并来保证复杂度。

最后,知道了当前点的所有儿子节点的凸包,如何计算当前点的答案?正解是二分斜率,因为你没法在set上直接二分直线。

总复杂度为\(O(nlog^2n)\)

还可以通过dfn序将树上问题转化为序列问题,对一个点有贡献的是一段连续区间,就可以用CDQ分治解决。

#include<bits/stdc++.h>
#define int long long
#define ld long double
//This code is written by Itst
using namespace std;

inline int read(){
    int a = 0;
    bool f = 0;
    char c = getchar();
    while(c != EOF && !isdigit(c)){
        if(c == '-')
            f = 1;
        c = getchar();
    }
    while(c != EOF && isdigit(c)){
        a = (a << 3) + (a << 1) + (c ^ '0');
        c = getchar();
    }
    return f ? -a : a;
}

const int MAXN = 100010;
struct Edge{
    int end , upEd;
}Ed[MAXN << 1];
struct node{
    int k , b;
}now;
set < node > s[MAXN];
long long sum[MAXN] , a[MAXN] , b[MAXN] , minN[MAXN] ,  fa[MAXN] , size[MAXN] , head[MAXN] , N , cntEd;

bool operator <(node a , node b){
    return a.k < b.k || (a.k == b.k && a.b < b.b);
}

inline void addEd(int a , int b){
    Ed[++cntEd].end = b;
    Ed[cntEd].upEd = head[a];
    head[a] = cntEd;
}

void dfs(int now , int f){
    fa[now] = f;
    size[now] = 1;
    for(int i = head[now] ; i ; i = Ed[i].upEd)
        if(Ed[i].end != f){
            dfs(Ed[i].end , now);
            size[now] += size[Ed[i].end];
        }
    if(size[now] == 1)
        minN[now] = 0;
}

inline ld calcNode(node a , node b){
    return (a.b - b.b) / (ld)(b.k - a.k);
}

inline void insert(int now , int k , int b){//插入一条斜率为k、截距为b的直线
    node x , l , r;
    x.k = k;
    x.b = b;
    set < node > :: iterator it = s[now].lower_bound(x);
    if(it != s[now].end() && (*it).k == k){//判断是否存在斜率相同的直线
        s[now].erase(it);
        it = s[now].lower_bound(x);
    }
    else
        if(it != s[now].begin() && (*--it).k == k)
            return;
    it = s[now].lower_bound(x);
    if(it != s[now].begin() && it != s[now].end()){
        l = *it , r = *(--it);
        if(calcNode(r , x) < calcNode(l , r) && calcNode(l , r) < calcNode(l , x))//判断这一条直线是否能够被加入当前凸包
            return;
        ++it;
    }
    while(1){//向两边删去不合法直线
        it = s[now].lower_bound(x);
        if(it == s[now].end())
            break;
        l = *it;
        if(++it == s[now].end())
            break;
        r = *it;
        if(calcNode(l , r) > calcNode(l , x) && calcNode(l , r) > calcNode(r , x))
            s[now].erase(--it);
        else
            break;
    }	
    while(1){
        it = s[now].lower_bound(x);
        if(it == s[now].begin())
            break;
        l = *(--it);
        if(it == s[now].begin())
            break;
        r = *(--it);
        if(calcNode(l , r) < calcNode(l , x) && calcNode(l , r) < calcNode(r , x))
            s[now].erase(++it);
        else
            break;
    }
    s[now].insert(x);
}

inline void merge(int root , int now){
    bool f = 0;
    if(size[root] < size[now]){
        swap(root , now);
        f = 1;
    }
    for(set < node > :: iterator it = s[now].begin() ; it != s[now].end() ; ++it)
        insert(root , (*it).k , (*it).b);
    if(f){
        s[now] = s[root];
        swap(now , root);
    }
    s[now].clear();
}

void dsu(int now){//别被误导了,这个不是dsu on tree
    if(size[now] == 1){
        insert(now , b[now] , minN[now]);
        return;
    }
    for(int i = head[now] ; i ; i = Ed[i].upEd)
        if(Ed[i].end != fa[now]){
            dsu(Ed[i].end);
            merge(now , Ed[i].end);
        }
    int l = -1e5 - 1 , r = 1e5 + 1;
    set < node > :: iterator it;
    node L , R;
    while(l < r){
        int mid = l + r >> 1;
        it = s[now].lower_bound((node){mid , (long long)-1e15});
        if(it == s[now].begin()){
            l = mid + 1;
            continue;
        }
        if(it == s[now].end()){
            r = mid;
            continue;
        }
        L = *it;
        R = *(--it);
        if(R.k * a[now] + R.b >= L.k * a[now] + L.b)
            l = mid + 1;
        else
            r = mid;
    }
    L = *s[now].lower_bound((node){l - 1 , (long long)-1e15});
    minN[now] = L.k * a[now] + L.b;
    insert(now , b[now] , minN[now]);
}

signed main(){
    N = read();
    memset(minN , 0x3f , sizeof(minN));
    for(int i = 1 ; i <= N ; i++)
        a[i] = read();
    for(int i = 1 ; i <= N ; i++)
        b[i] = read();
    for(int i = 1 ; i < N ; i++){
        int a = read() , b = read();
        addEd(a , b);
        addEd(b , a);
    }
    dfs(1 , 0);
    dsu(1);
    for(int i = 1 ; i <= N ; i++)
        cout << minN[i] << ' ';
    return 0;
}
posted @ 2019-01-28 08:18  cjoier_Itst  阅读(425)  评论(0编辑  收藏  举报