NOIP 2015

Day2

T3 运输计划

题意:有一颗\(n\)个点的树,每条边有长度,给出树上m条路径,你可以选择一条边,将其长度改为0,求这m条路径中最长的最短可以是多少

\(n,m \le 3\times10^5\),边长不超过\(1000\)

解法:

\(1.\)二分答案,把一开始长度超限的求交集,若交集为0,无解,否则选交集中最长的边

\(2.\)树上差分求交(端点$ -1, LCA + 2$)

时间复杂度:\(O(nlog^2n)\)

\(1.\)考虑两两路径的交集,则发现它一定在这两路径端点的\(4\)个LCA中最深的两个LCA的路径中(本题可以这样做,其它题还要考虑\(4\)个LCA重合的情况等)。

\(2.\)这样做发现二分的用处不大,可以把边从大到小排序,依次求交,然后枚举答案即可,如果交集为空或者最长的长度不够用了,就输出答案。可以用\(RMQ\)优化求\(LCA\)

时间复杂度:\(O(nlogn)\) 听说用一些神奇的方法可以将\(RMQ\)的预处理优化到\(O(n)\),使总时间复杂度达到\(O(n)\)

对于第一种\(O(nlog^2n)\)的解法,我们发现可以直接预处理出每对的LCA,这样时间复杂度将会降到\(O(nlogn)\)

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define maxn 300100
#define RE register 
int n, m;
int st[maxn], ed[maxn], len[maxn];
int fir[maxn], nxt[maxn * 2], vv[maxn * 2], edge[maxn * 2];
int tot = 0;
int read()
{
    int ret = 0;
    char ch = getchar();
    while(ch < '0' || ch > '9')
    {
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9')
    {
        ret = ret * 10 + ch - '0';
        ch = getchar();
    }
    return ret;
}
void add(int u, int v, int w)
{
    nxt[++tot] = fir[u];
    fir[u] = tot;
    vv[tot] = v;
    edge[tot] = w;
}
int dep[maxn], f[maxn][25], g[maxn][25];
void Deal_first(int u, int fa)
{
 //   printf("u = %d fa = %d\n", u, fa);
    dep[u] = dep[fa] + 1;
    for(RE int i = 0; i <= 19; i++)
    {
        f[u][i + 1] = f[f[u][i]][i];
        g[u][i + 1] = g[u][i] + g[f[u][i]][i];
    }
    for(RE int i = fir[u]; i; i = nxt[i])
    {
        int v = vv[i];
        if(v == fa) continue;
        f[v][0] = u;
        g[v][0] = edge[i];
        Deal_first(v, u);
    }
}
int LCA(int x, int y, int &dis)
{
    if(dep[x] < dep[y]) swap(x, y);
    for(RE int i = 20; i >= 0; i--)
    {
        if(dep[f[x][i]] >= dep[y])
        {
            dis += g[x][i];
    //        printf("g[x][i] = %d\n", g[x][i]);
            x = f[x][i];
        }
   //     printf("x = %d y = %d\n", x, y);
        if(x == y)
        {
            return x;
        }
    }
    for(RE int i = 20; i >= 0; i--)
    {
        if(f[x][i] != f[y][i])
        {
            dis += g[x][i];
            dis += g[y][i];
            x = f[x][i];
            y = f[y][i];
    //        printf("x = %d y = %d\n", x, y);
        }
    }
    dis += (g[x][0] + g[y][0]);
    return f[x][0];
}
int val[maxn], dp[maxn * 2], from[maxn], lc[maxn];
void dfs(int u, int fa)
{
    dp[from[u]] += val[u];
    for(RE int i = fir[u]; i; i = nxt[i])
    {
        int v = vv[i];
        if(v == fa) continue;
        from[v] = i;
        dfs(v, u);
        dp[from[u]] += dp[from[v]];
    }
}
int lenth;
int check(int x)
{
    memset(val, 0, sizeof(val));
    memset(dp, 0, sizeof(dp));
    int cnt = 0;
    for(RE int i = 1; i <= m; i++)
    {
        if(len[i] > x)
        {
            cnt++;
            int u = st[i], v = ed[i];
            int xx = 0;
            val[u] += 1; val[v] += 1;
            val[lc[i]] -= 2;
        }
   //     else break;
    }
    dfs(1, 0);
    int maxx = -1;
    for(RE int i = 1; i <= tot; i++)
    {
    //    printf("dp[%d] = %d ", i, dp[i]);
        if(dp[i] >= cnt)
        {
            maxx = max(maxx, edge[i]);
        }
    }
   // printf("\n");
   // printf("x = %d cnt = %d maxx = %d\n", x, cnt, maxx);
    if(!cnt) return 1;
    if(maxx == -1 || lenth - maxx > x) return 0;
    return 1;
}
bool cmp(int x, int y)
{
    return x > y;
}
int main()
{
    n = read(); m = read();
    for(RE int i = 1; i < n; i++)
    {
        int u, v, w;
        u = read(); v = read(); w = read();
        add(u, v, w); add(v, u, w);
    }
    for(RE int i = 1; i <= m; i++)
    {
        st[i] = read(); ed[i] = read();
    }
    Deal_first(1, 0);
    for(RE int i = 1; i <= m; i++)
    {
        int dis = 0;
        int fa = LCA(st[i], ed[i], dis);
        len[i] = dis; lenth = max(lenth, len[i]);
        lc[i] = fa;
    //    printf("u = %d v = %d fa = %d\n", st[i], ed[i], fa);
        //len[i] = dep[st[i]] + dep[ed[i]] - dep[LCA(st[i], ed[i])] * 2;
    }
   // sort(len + 1, len + m + 1, cmp);
   // for(int i = 1; i <= m; i++) printf("len[%d] = %d\n", i, len[i]);
    int l = 0, r = 3e8, ans = 0;
    while(l <= r)
    {
        int mid = (l + r) >> 1;
        if(check(mid) == 1)
        {
            ans = mid;
            r = mid - 1;
        }
        else l = mid + 1;
    }
    printf("%d\n", ans);
    return 0;
}
posted @ 2019-10-10 17:29  Akaina  阅读(102)  评论(0编辑  收藏  举报