「BZOJ 2809」「APIO 2012」Dispatching「启发式合并」

题意

给定一个\(1\)为根的树,每个点有\(c,w\)两个属性,你需要从某个点\(u\)子树里选择\(k\)个点,满足选出来的点\(\sum_{i=1}^k w(i)\leq m\),最大化\(k\times c(u)\)

题解

可以启发式合并\(splay\)来做,\(\text{dfs}\)每个点,每次和儿子的\(splay\)合并,就得到了一个维护这个点子树的平衡树。再用这个点的\(w\)(题目中的领导力)乘以子树中最多选多少结点满足\(c\)(薪水)和\(\leq m\)

肯定贪心选,小的先选

然后这玩意我一开始脑抽写了二分\(+kth\)的两个\(\log\)做法,会\(\text{TLE}\)一个点,\(97pts\)

int kth(int u, int k) {
    while(1) {
        if(sz[ch[u][0]] >= k) u = ch[u][0];
        else {
            k -= sz[ch[u][0]] + 1;
            if(k <= 0) break ;
            u = ch[u][1];
        }
    }
    splay(u);
    return u;
}

ll qsum(int &u, int k) {
    u = kth(u, k);
    return sum[ch[u][0]] + w[u];
}
ll query(int &u) {
    int l = 1, r = sz[u], ans = 0;
    while(l <= r) {
        int mid = (l + r) >> 1;
        if(qsum(u, mid) <= m) l = (ans = mid) + 1;
        else r = mid - 1;
    }
    return ans;
}

后来发现直接递归找就行了,十分简单,从左子树往自己往右子树依次判能不能选进去

最后的代码就是这样了qwq:

#include <algorithm>
#include <cstdio>
#include <vector>
using namespace std;

typedef long long ll;

const int N = 2e5 + 10;

int n, m, w[N], c[N];
int fa[N], sz[N], ch[N][2];
vector<int> G[N];
ll ans, sum[N];

int dir(int u) { return ch[fa[u]][1] == u; }
void upd(int u) {
    sum[u] = sum[ch[u][0]] + sum[ch[u][1]] + w[u];
    sz[u] = sz[ch[u][0]] + sz[ch[u][1]] + 1;
}
void rotate(int u) {
    int d = dir(u), f = fa[u];
    if(fa[u] = fa[f]) ch[fa[u]][dir(f)] = u;
    if(ch[f][d] = ch[u][d ^ 1]) fa[ch[f][d]] = f;
    fa[ch[u][d ^ 1] = f] = u;
    upd(f); upd(u);
}
void ins(int &rt, int u, int f = 0) {
    if(!rt) {
        rt = u; fa[u] = f;
        sz[u] = 1; sum[u] = w[u];
        ch[u][0] = ch[u][1] = 0;
        return ;
    }
    ins(ch[rt][w[u] > w[rt]], u, rt);
    upd(rt);
}
void splay(int u) {
    for(; fa[u]; rotate(u)) if(fa[fa[u]])
        rotate(dir(u) == dir(fa[u]) ? fa[u] : u);
}
ll query(int u, ll s) {
    if(!u || !s) return 0;
    if(sum[u] <= s) return sz[u];
    if(sum[ch[u][0]] >= s) return query(ch[u][0], s);
    if(sum[ch[u][0]] + w[u] == s) return sz[ch[u][0]] + 1;
    if(sum[ch[u][0]] + w[u] > s) return sz[ch[u][0]];
    return sz[ch[u][0]] + 1 + query(ch[u][1], s - sum[ch[u][0]] - w[u]);
}

int a[N], l;
void get(int u) {
    if(u) {
        get(ch[u][0]);
        a[++ l] = u;
        get(ch[u][1]);
    }
}
int merge(int u, int v) {
    if(sz[u] < sz[v]) swap(u, v);
    l = 0; get(v); int rt = u;
    for(int i = 1; i <= l; i ++) {
        ins(rt, a[i]);
        splay(rt = a[i]);
    }
    return rt;
}

int dfs(int u) {
    int rt = u; sz[u] = 1; sum[u] = w[u];
    fa[u] = ch[u][0] = ch[u][1] = 0;
    for(int i = 0; i < G[u].size(); i ++)
        rt = merge(rt, dfs(G[u][i]));
    ans = max(ans, c[u] * query(rt));
    return rt;
}

int main() {
    scanf("%d%d", &n, &m);
    for(int f, i = 1; i <= n; i ++) {
        scanf("%d%d%d", &f, &w[i], &c[i]);
        G[f].push_back(i);
    }
    dfs(1);
    printf("%lld\n", ans);
    return 0;
}

posted @ 2019-02-14 21:31  hfhongzy  阅读(212)  评论(0编辑  收藏  举报