bzoj4326

二分+树剖+差分

之前的做法naive,莫名其妙的wa,明明uoj95分

看到最小最大上二分,树上路径问题直接剖,然后问题就转化成了一个判定问题,每次二分出最长路径长度,问能不能达到。那么我们就把所有长度大于二分出的d的路径拉出来,求出他们公共路径的最大长度,看减去能不能满足。那么现在的问题就转化成了求路径的交。我们可以利用bzoj4390的树上差分的方法解决,我们先把边转化为点,即每条边等价于他连接的深度较深的点,那么我们利用树上差分标记路径上的边,把路径上的边标记+1,这个我们只要在u,v,+1,lca,-2就行了,然后我们求一个前缀和,看子树权值是否等于路径数量就行了。

树上差分是一种重要的思想,大概有几种方法,

1.像这道题和bzoj4390,可以标记路径上的点/边

2.像noip2016day1t2,雨天的尾巴,利用树剖重链是一段连续的区间且最多只有log条重链,我们可以log时间修改路径的信息,这和上面的差分是不一样的,因为如果按上面的方法差分那么可能会把这条路径上的信息带到其他的点,而树剖保证dfs序连续,所以每条重链差分一下,那么就不会带入到其他点,这样一共会打logn个标记

原先是像魔法森林那样把边权转化为点权,时间爆炸

#include<bits/stdc++.h>
using namespace std;
const int N = 300010;
struct edge {
    int nxt, to, w;
} e[N << 1];
struct path {
    int u, v, w;
    bool friend operator < (path a, path b) { return a.w > b.w; }
} a[N];
namespace IO 
{
    const int Maxlen = N * 50;
    char buf[Maxlen], *C = buf;
    int Len;
    inline void read_in()
    {
        Len = fread(C, 1, Maxlen, stdin);
        buf[Len] = '\0';
    }
    inline void fread(int &x) 
    {
        x = 0;
        int f = 1;
        while (*C < '0' || '9' < *C) { if(*C == '-') f = -1; ++C; }
        while ('0' <= *C && *C <= '9') x = (x << 1) + (x << 3) + *C - '0', ++C;
        x *= f;
    }
    inline void read(int &x)
    {
        x = 0;
        int f = 1; char c = getchar();
        while(c < '0' || c > '9') { if(c == '-') f = -1; c = getchar(); }
        while(c >= '0' && c <= '9') { x = (x << 1) + (x << 3) + c - '0'; c = getchar(); }
        x *= f;
    }
    inline void read(long long &x)
    {
        x = 0;
        long long f = 1; char c = getchar();
        while(c < '0' || c > '9') { if(c == '-') f = -1; c = getchar(); }
        while(c >= '0' && c <= '9') { x = (x << 1ll) + (x << 3ll) + c - '0'; c = getchar(); }
        x *= f;
    } 
} using namespace IO;
int n, m, cnt = 1, tot, lim;
int tree[N], mark[N], mir[N], head[N], dis[N], size[N], son[N], dep[N], fa[N], top[N], Lca[N], w[N], in[N], out[N];
void link(int u, int v, int w)
{
    e[++cnt].nxt = head[u];
    head[u] = cnt;
    e[cnt].to = v;
    e[cnt].w = w;
}
void dfs(int u, int last)
{
    size[u] = 1;
    dis[u] = dis[last] + w[u];
    for(int i = head[u]; i; i = e[i].nxt) if(e[i].to != last)
    {
        dep[e[i].to] = dep[u] + 1;
        fa[e[i].to] = u;
        w[e[i].to] = e[i].w;
        dfs(e[i].to, u);
        size[u] += size[e[i].to];
        if(size[e[i].to] > size[son[u]]) son[u] = e[i].to;
    }
}
void dfs(int u, int acs, int last)
{
    in[u] = ++tot;
    mir[tot] = u;
    top[u] = acs;
    if(son[u]) dfs(son[u], acs, u);
    for(int i = head[u]; i; i = e[i].nxt) if(e[i].to != last && e[i].to != son[u]) dfs(e[i].to, e[i].to, u);
    out[u] = tot;
}
int lca(int u, int v)
{
    while(top[u] != top[v])
    {
        if(dep[top[u]] < dep[top[v]]) swap(u, v);
        u = fa[top[u]];
    }
    return dep[u] < dep[v] ? u : v;
}
int Dis(int u, int v)
{
    int x = lca(u, v);
    return dis[u] + dis[v] - 2 * dis[x]; 
}
void update(int x, int d)
{
    for(; x <= n; x += x & (-x)) tree[x] += d;
}
int query(int x)
{
    int ret = 0;
    for(; x; x -= x & (-x)) ret += tree[x];
    return ret;
}
bool check(int d)
{
    memset(mark, 0, sizeof(mark));
    int cou = 0, mx = 0;
    for(int i = 1; i <= m; ++i) 
    {
        if(a[i].w <= d) break;
        ++cou;
        mark[in[Lca[i]]] -= 2;
        ++mark[in[a[i].u]];
        ++mark[in[a[i].v]];
    }
    for(int i = 1; i <= n; ++i) mark[i] += mark[i - 1];
    for(int i = 1; i <= n; ++i) 
    {
        int tmp = mark[out[i]] - mark[in[i] - 1];
        if(tmp == cou) mx = max(mx, w[i]);
    }    
    return a[1].w - mx <= d;
}
int main()
{
    read_in();
    fread(n);
    fread(m);
    for(int i = 1; i < n; ++i)
    {
        int u, v, w;
        fread(u);
        fread(v);
        fread(w);
        link(u, v, w);
        link(v, u, w);
    }
    dfs(1, 0);
    dfs(1, 1, 0);
    for(int i = 1; i <= m; ++i) 
    {
        fread(a[i].u);
        fread(a[i].v);
        a[i].w = Dis(a[i].u, a[i].v);
        lim = max(lim, a[i].w);
    }    
    sort(a + 1, a + m + 1);
    for(int i = 1; i <= m; ++i) Lca[i] = lca(a[i].u, a[i].v);
    int l = -1, r = lim + 1, ans = 0;
    while(r - l > 1)
    {
        int mid = (l + r) >> 1;
        if(check(mid)) r = ans = mid;
        else l = mid;
    }
    printf("%d\n", ans);
    return 0;
}
View Code

 

posted @ 2017-09-15 11:56  19992147  阅读(173)  评论(0编辑  收藏  举报