Loading

POJ-3417 Network

Network

树上差分 + 倍增LCA

树上添加一条边,会形成一个环,环上的边如果随便断开一条,仍然是属于同一个集合里

因此我们可以计算一下每一条边在多少个环内,最后再对每一条边进行判断,有多少种删除边的方法,计算的话可以用树上差分 \(O(1)\) 解决

假设一条边处在 \(x\) 个环内

  1. \(x = 0\): 删除该边,然后可以选择任意一个新添加的边进行删除,所以贡献就是 \(m\)

  2. \(x = 1\): 删除该边,同时也要删除所处环上的那条新加边,因此贡献就是 \(1\)

  3. \(x \ge 2\): 不管删除那一条新添加的边的,都不会贡献答案

坑坑坑坑点(如果 TLE or RE):

  1. 不要用邻接表,会 TLE,用链式前向星就可以了

  2. 倍增要看看数组大小对不对

  3. 不用扩展栈空间

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <queue>
using namespace std;
typedef long long ll;
const int maxn = 2e5 + 10;
int dep[maxn], fa[maxn][21];
ll diff[maxn];
int top[maxn], nexx[maxn], to[maxn];

struct node
{
    int now, pre, d;
    node(){}
    node(int _now, int _pre, int _d){now = _now; pre = _pre; d = _d;}
};

void dfs1(int now, int pre, int d)
{
    fa[now][0] = pre;
    dep[now] = d;
    for(int i=top[now]; i; i=nexx[i])
    {
        int nex = to[i];
        if(nex == pre) continue;
        dfs1(nex, now, d + 1);
    }
}

void init(int n)
{
    dfs1(1, 1, 1);
    for(int i=1; i<=20; i++)
        for(int j=1; j<=n; j++)
            fa[j][i] = fa[fa[j][i-1]][i-1];
}

int LCA(int a, int b)
{
    if(dep[a] < dep[b]) swap(a, b);
    int dif = dep[a] - dep[b];
    for(int i=20; i>=0; i--)
    {
        if(dif >= (1 << i))
        {
            dif -= 1 << i;
            a = fa[a][i];
        }
    }
    if(a == b) return a;
    for(int i=20; i>=0; i--)
    {
        if(fa[a][i] != fa[b][i])
        {
            a = fa[a][i];
            b = fa[b][i];
        }
    }
    return fa[a][0];
}

inline void add(int a, int b)
{
    int lca = LCA(a, b);
    diff[a]++;
    diff[b]++;
    diff[lca] -= 2;
}

void dfs2(int now, int pre)
{
    for(int i=top[now]; i; i=nexx[i])
    {
        int nex = to[i];
        if(nex == pre) continue;
        dfs2(nex, now);
        diff[now] += diff[nex];
    }
}

int tp = 0;
inline void add_line(int a, int b)
{
    tp++;
    nexx[tp] = top[a];
    top[a] = tp;
    to[tp] = b;
}

int main()
{
    int n, m;
    scanf("%d%d", &n, &m);
    for(int i=1; i<n; i++)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        add_line(a, b);
        add_line(b, a);
    }
    init(n);
    for(int i=0; i<m; i++)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        add(a, b);
    }
    dfs2(1, 1);
    ll ans = 0;
    for(int i=2; i<=n; i++)
    {
        if(diff[i] == 0) ans += m;
        else if(diff[i] == 1) ans++;
    }
    printf("%lld\n", ans);
    return 0;
}
posted @ 2022-07-31 16:27  dgsvygd  阅读(19)  评论(0编辑  收藏  举报