【NOIP2015提高组】运输计划

https://daniu.luogu.org/problem/show?pid=2680

使完成所有运输计划的时间最短,也就是使时间最长的运输计划耗时最短。最大值最小问题考虑用二分答案,每次check(mid)检查时间最长的运输计划耗时是否小于等于mid,二分出使得check(mid)==true的最小mid值。

check函数怎么写是本题的难点。
耗时小于mid的运输计划不会影响check的结果。耗时大于mid的运输计划肯定需要改造他们的共同边才有可能使它们耗时都小于mid,而有多条共同边的时候肯定是改权值最大的更合算。如果改造了这条边可以使得原来时间最长的运输计划耗时也小于mid,则返回true,否则返回false。

所以读入数据时需要预处理下每个运输计划的耗时。

问题就在于怎么判断是否有公共边了。首先无根树转有根树,为了方便判断把边的权值放到子结点上。
可以用树剖+线段树把每一条路径上经过的所有点(除LCA)计数加一,然后看计数最大的点。但是更快更方便的方法是树上差分——对于每一条路径给两端的结点计数加1,给LCA计数减2。统计完之后做一遍树上前缀和,还可以在这个过程顺便求出计数最多的点。

注意这题卡常数非常厉害,记得用上快速读入、链式前向星、启发式合并+路径压缩的并查集。

#include <algorithm>
#include <cstring>
#include <iostream>
#include <vector>
#define maxn 300005
using namespace std;
void scan(int &x)
{
    x = 0;
    char c;
    bool flag = false;
    while (!isdigit(c = getchar()))
    {
        if (c == '-')
            flag = true;
        if (c == EOF)
            return;
    }
    do
        x = x * 10 + c - '0';
    while (isdigit(c = getchar()));
    if (flag)
        x = -x;
}

int n, m;
struct edge
{
    int next, to, weight;
} edges[maxn * 2];
int head[maxn], ecnt = 1;
void add_edge(int u, int v, int w)
{
    edges[ecnt].to = v;
    edges[ecnt].weight = w;
    edges[ecnt].next = head[u];
    head[u] = ecnt++;

    edges[ecnt].to = u;
    edges[ecnt].weight = w;
    edges[ecnt].next = head[v];
    head[v] = ecnt++;
}
int weight[maxn], parent[maxn], length[maxn];
void build_tree(int v, int fr, int wei)
{
    parent[v] = fr;
    weight[v] = wei;
    length[v] = length[fr] + wei;
    for (int i = head[v]; i; i = edges[i].next)
    {
        if (edges[i].to != fr)
        {
            build_tree(edges[i].to, v, edges[i].weight);
        }
    }
}

namespace djs
{
int djs_parent[maxn];
void init()
{
    for (int i = 1; i <= n; i++)
        djs_parent[i] = -1;
}
int find(int x)
{
    if (djs_parent[x] < 0)
        return x;
    else
        return djs_parent[x] = find(djs_parent[x]);
}
void merge(int x, int y)
{
    x = find(x);
    y = find(y);
    djs_parent[y] = x;
}
}
int from[maxn], to[maxn], cost[maxn], lca[maxn];
vector<int> query_index[maxn];
bool visited[maxn];
void get_lca(int v)
{
    for (int i = head[v]; i; i = edges[i].next)
    {
        int w = edges[i].to;
        if (w != parent[v])
        {
            get_lca(w);
            djs::merge(v, w);
        }
    }
    visited[v] = true;

    for (int i = 0; i < query_index[v].size(); i++)
    {
        int q = query_index[v][i];
        int w = (from[q] == v) ? to[q] : from[q];
        if (visited[w])
            lca[q] = djs::find(w);
    }
}

int mark[maxn], maxmark;
void markdown(int i)
{
    mark[from[i]]++;
    mark[to[i]]++;
    mark[lca[i]] -= 2;
}
void push_down(int v)
{
    for (int i = head[v]; i; i = edges[i].next)
    {
        int w = edges[i].to;
        if (w != parent[v])
        {
            push_down(w);
            mark[v] += mark[w];
        }
    }
    if (mark[v] > mark[maxmark])
        maxmark = v;
    else if (mark[v] == mark[maxmark] && weight[v] > weight[maxmark])
        maxmark = v;
}
bool check(int k)
{
    for (int i = 1; i <= n; i++)
        mark[i] = 0;
    int maxcost = 0;
    int cnt = 0;
    for (int i = 1; i <= m; i++)
    {
        if (cost[i] > k)
        {
            cnt++;
            markdown(i);
            maxcost = max(maxcost, cost[i]);
        }
    }
    if (cnt == 0)
        return true;
    maxmark = 0;
    push_down(1);
    if (mark[maxmark] >= cnt && maxcost - weight[maxmark] <= k)
        return true;
    else
        return false;
}
int main()
{
    scan(n);
    scan(m);
    int a, b, c;
    for (int i = 1; i < n; i++)
    {
        scan(a);
        scan(b);
        scan(c);
        add_edge(a, b, c);
    }
    build_tree(1, 0, 0);
    for (int i = 1; i <= m; i++)
    {
        scan(from[i]);
        scan(to[i]);
        query_index[from[i]].push_back(i);
        query_index[to[i]].push_back(i);
    }
    djs::init();
    get_lca(1);
    for (int i = 1; i <= m; i++)
    {
        cost[i] = length[from[i]] + length[to[i]] - 2 * length[lca[i]];
    }

    int l = 0, r = 1000 * maxn, mid;
    while (l < r)
    {
        mid = (l + r) / 2;
        if (check(mid))
            r = mid;
        else
            l = mid + 1;
    }
    cout << l << endl;
    return 0;
}

 

posted @ 2017-09-16 11:22  ssttkkl  阅读(498)  评论(0编辑  收藏  举报