【AtCoder Grand Contest 007E】Shik and Travel [Dfs][二分答案]

Shik and Travel

Time Limit: 50 Sec  Memory Limit: 512 MB

Description

  给定一棵n个点的树,保证一个点出度为2/0。

  遍历一遍,要求每条边被经过两次,第一次从根出发,最后一次到根结束,在叶子节点之间移动。

  移动一次的费用为路径上的边权之和,第一次和最后一次免费,移动的最大费用 最小可以是多少。

Input

  第一行一个n,表示点数。

  之后两个数x, y,若在第 i 行,表示 i+1 -> x 有一条权值为 y 的边。

Output

  输出一个数表示答案。

Sample Input

  7
  1 1
  1 1
  2 1
  2 1
  3 1
  3 1

Sample Output

  4

HINT

  2 < n < 131,072
  0 ≤ y ≤ 131,072

Solution

  问题的本质就是:求一个叶子节点排列,按照排列顺序走,使得两两距离<=K
  因为第一天和最后一天不花费,可以第一天从根走到一个叶子,最后一天从某一叶子走回根。
  我们首先二分答案

  对于子树u维护二元组(a, b),表示存在方案可以 从 与u距离为a的点 出发 然后走到 与u距离为b的点,并且遍历了u中的所有叶子节点
  用个vector存一下即可。显然,若a升序则b要降序,否则是无用状态

  运用Dfs从叶子节点往上推。我们现在考虑如何合并子树u、v的(a, b)。给一棵子树编号(a1, b1),另一棵为(a2, b2)
  我们新二元组的走法应该是 a1->b1, b1->a2, a2->b2 的,
  只要保证 b1->a2 这一条路径 权值和<=K 即可合并成(a1 + (u->fa), b2 + (v->fa))
  显然用(a1, b1)去合并只有一个有用状态:满足b1 + a2 + (u->fa) + (v->fa)<=Ka2尽量大,因为这样b2会尽量小
  枚举size较小的一边二分一下另外一边即可。

  若推到根存在一组二元组即可行。

Code

#include<iostream>
#include<string>
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cmath>
#include<vector>
using namespace std;
typedef long long s64;

const int ONE = 1000005;
const s64 INF = 1e18;

#define next nxt

int get()
{
        int res = 1, Q = 1; char c;
        while( (c = getchar()) < 48 || c > 57)
            if(c == '-') Q = -1;
        if(Q) res = c - 48;
        while( (c = getchar()) >= 48 && c <= 57)
            res = res * 10 + c - 48;
        return res * Q;
}

struct power
{
        s64 a, b;
        bool operator <(power A) const
        {
            if(A.a != a) return a < A.a;
            return b < A.b;
        }
};
vector <power> A[ONE], R;

int n;
int x, y;
s64 l, r, K;
int next[ONE], first[ONE], go[ONE], w[ONE], tot;
int size[ONE], fat[ONE];

void Add(int u, int v, int z)
{
        next[++tot] = first[u], first[u] = tot, go[tot] = v, w[tot] = z;
        next[++tot] = first[v], first[v] = tot, go[tot] = u, w[tot] = z;
}

int dist[ONE];
int len_x, len_y;
s64 a1, b1, a2, b2;

int Find()
{
        if(len_y == 0) return len_y;
        int l = 0, r = len_y - 1;
        while(l < r - 1)
        {
            int mid = l + r >> 1;
            a2 = A[y][mid].a, b2 = A[y][mid].b;
            if(b1 + a2 + dist[x] + dist[y] <= K) l = mid;
            else r = mid;
        }
        a2 = A[y][r].a, b2 = A[y][r].b; if(b1 + a2 + dist[x] + dist[y] <= K) return r;
        a2 = A[y][l].a, b2 = A[y][l].b; if(b1 + a2 + dist[x] + dist[y] <= K) return l;
        return len_y;
}

void Update(int u)
{
        x = 0, y = 0;
        for(int e = first[u]; e; e = next[e])
            if(go[e] != fat[u])
                if(!x) x = go[e]; else y = go[e];

        if(size[x] > size[y]) swap(x, y);

        len_x = A[x].size(), len_y = A[y].size();

        R.clear();
        for(int i = 0; i < len_x; i++)
        {
            a1 = A[x][i].a, b1 = A[x][i].b;
            if(Find() >= len_y) continue;
            R.push_back((power){a1 + dist[x], b2 + dist[y]});
            R.push_back((power){b2 + dist[y], a1 + dist[x]});
        }


        sort(R.begin(), R.end());
        int len = R.size();
        s64 maxx = INF;
        for(int i = 0; i < len; i++)
            if(R[i].b < maxx)
                A[u].push_back(R[i]), maxx = R[i].b;
}

void Dfs(int u, int father)
{
        size[u] = 1;
        int pd = 0;
        for(int e = first[u]; e; e = next[e])
        {
            int v = go[e];
            if(v == father) continue;
            fat[v] = u, dist[v] = w[e];
            Dfs(v, u);
            size[u] += size[v], pd++;
            if(pd == 2) Update(u);
        }
        if(!pd) A[u].push_back((power){0, 0});
}

int Check()
{
        for(int i = 1; i <= n; i++)
            A[i].clear();
        Dfs(1, 0);
        return A[1].size() > 0;
}

int main()
{
        n = get();
        for(int i = 2; i <= n; i++)
        {
            x = get(), y = get();
            Add(i, x, y), r += y;
        }


        while(l < r - 1)
        {
            K = l + r >> 1;
            if(Check()) r = K;
            else l = K;
        }

        K = l;
        if(Check()) printf("%lld", l);
        else printf("%lld", r);
}
View Code

 

  • 广泛的交换
posted @ 2017-11-25 21:34  BearChild  阅读(490)  评论(0编辑  收藏  举报