BZOJ #2238. Mst (最小生成树+树链剖分+线段树)

BZOJ #2238. Mst (最小生成树+树链剖分+线段树)

Description

给出一个N个点M条边的无向带权图,以及Q个询问,每次询问在图中删掉一条边后图的最小生成树。(各询问间独立,每次询问不对之后的询问产生影响,即被删掉的边在下一条询问中依然存在)

Input

第一行两个正整数N,M(N<=50000,M<=100000)表示原图的顶点数和边数。

下面M行,每行三个整数X,Y,W描述了图的一条边(X,Y),其边权为W(W<=10000)。保证两点之间至多只有一条边。

接着一行一个正整数Q,表示询问数。(1<=Q<=100000)

下面Q行,每行一个询问,询问中包含一个正整数T,表示把编号为T的边删掉(边从1到M按输入顺序编号)。

思路:

我们可以首先对图跑一个最小生成树算法得到图的MST,

如果无法的得出生成树,那么直接特判接下来每一个输出都是无解。

如果有生成树,那么对于所有非生成树上的边\((v,u)\)

它可以代替最小生成树上\(v->u\) 这个路径上所有边

那么我们可以在建立MST之后,进行树链剖分,同时建立线段树,功能需要有:

可以区间修改最小值,区间查询最小值。

然后对于每一个非树边,我们可以区间更新路径区间中的最小值,(注意这里需要一个边权转点权的操作)。

那么对于每一个边被删除之后的MST,分为两种情况:

1、该边是非树边,那么答案还是最初的MST。

2、该边是树边,那么该边应该替换为非树边构成的路径覆盖掉该边的所有非树边中权值最小的那个。显然我们可以直接通过维护的线段树得到那个最小权值,如果权值为无穷大(建立线段树的初始值),那么就代表没有非树边可以cover该树边,答案为无解。

// 本题有一个巨坑的地方,题面说保证两点之间至多只有一条边,但是数据中却有这样的情况,本人用assert亲测。

代码:

#include <bits/stdc++.h>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <vector>
#include <iomanip>
#define ALL(x) (x).begin(), (x).end()
#define sz(a) int(a.size())
#define all(a) a.begin(), a.end()
#define rep(i,x,n) for(int i=x;i<n;i++)
#define repd(i,x,n) for(int i=x;i<=n;i++)
#define pii pair<int,int>
#define pll pair<long long ,long long>
#define gbtb ios::sync_with_stdio(false),cin.tie(0),cout.tie(0)
#define MS0(X) memset((X), 0, sizeof((X)))
#define MSC0(X) memset((X), '\0', sizeof((X)))
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define eps 1e-6
#define chu(x) cout<<"["<<#x<<" "<<(x)<<"]"<<endl
using namespace std;
typedef long long ll;
ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
ll lcm(ll a, ll b) {return a / gcd(a, b) * b;}
ll powmod(ll a, ll b, ll MOD) {ll ans = 1; while (b) {if (b % 2)ans = ans * a % MOD; a = a * a % MOD; b /= 2;} return ans;}
inline void getInt(int* p);
const int maxn = 100000 + 10;
const int inf = 0x3f3f3f3f;
/*** TEMPLATE CODE * * STARTS HERE ***/
int n, m;
int root;
int cnt;// 编号用的变量
int top[maxn];// 所在重链的顶点编号
int id[maxn];//节点的新编号。
typedef pair<int, ll> pil;
std::vector<pil> son[maxn];
int SZ[maxn];// 子数大小
int wson[maxn];// 重儿子
int fa[maxn];// 父节点
int dep[maxn];// 节点的深度
void dfs1(int id, int pre, int step) // 维护出sz,wson,fa,dep
{
    dep[id] = step;
    fa[id] = pre;
    SZ[id] = 1;
    int  maxson = -1;
    for (auto x : son[id])
    {
        if (x.fi != pre)
        {
            dfs1(x.fi, id, step + 1);
            SZ[id] += SZ[x.fi];
            if (SZ[x.fi] > maxson)
            {
                maxson = SZ[x.fi];
                wson[id] = x.fi;
            }
        }
    }
}

//处理出top[],wt[],id[]
void dfs2(int u, int topf)
{
    id[u] = ++cnt;
    top[u] = topf;
    if (!wson[u]) // 没儿子时直接结束
    {
        return ;
    }
    dfs2(wson[u], topf); // 先处理重儿子
    for (auto x : son[u])
    {
        if (x.fi == wson[u] || x.fi == fa[u]) //只处理轻儿子
        {
            continue;
        }
        dfs2(x.fi, x.fi); // 每个轻儿子以自己为top
    }
}

struct node
{
    int l, r;
    int sum;
    int laze;
} segment_tree[maxn << 2];

void build(int rt, int l, int r)
{
    segment_tree[rt].l = l;
    segment_tree[rt].r = r;
    segment_tree[rt].laze = inf;
    segment_tree[rt].sum = inf;
    if (l == r)
    {
        return;
    }
    int mid = (l + r) >> 1;
    build(rt << 1, l, mid);
    build(rt << 1 | 1, mid + 1, r);
}
void push_down(int rt)
{
    if (segment_tree[rt].laze != inf)
    {
        int num = segment_tree[rt].laze;
        segment_tree[rt << 1].sum = min(num, segment_tree[rt << 1].sum);
        segment_tree[rt << 1].laze = min(num, segment_tree[rt << 1].laze);
        segment_tree[rt << 1 | 1].sum = min(num, segment_tree[rt << 1 | 1].sum);
        segment_tree[rt << 1 | 1].laze = min(num, segment_tree[rt << 1 | 1].laze);
        segment_tree[rt].laze = inf;
    }
}
void update(int rt, int l, int r, int w)
{
    if (segment_tree[rt].r < l || segment_tree[rt].l > r)
        return ;
    if (l <= segment_tree[rt].l && segment_tree[rt].r <= r)
    {
        segment_tree[rt].sum = min(segment_tree[rt].sum, w);
        segment_tree[rt].laze = min(segment_tree[rt].laze, w);
        return ;
    }
    push_down(rt);
    int mid = (segment_tree[rt].l + segment_tree[rt].r) >> 1;
    if (r <= mid)
    {
        return update(rt << 1, l, r, w);
    } else
    {
        if (l > mid)
        {
            update(rt << 1 | 1, l, r, w);
        }
        else
        {
            update(rt << 1, l, r, w);
            update(rt << 1 | 1, l, r, w);
        }
    }
}

int query(int rt, int l, int r)
{
    if (segment_tree[rt].r < l || segment_tree[rt].l > r)
        return inf;
    if (segment_tree[rt].l >= l && r >= segment_tree[rt].r)
    {
        return segment_tree[rt].sum;
    }
    push_down(rt);
    int mid = (segment_tree[rt].l + segment_tree[rt].r) >> 1;
    if (r <= mid)
    {
        return query(rt << 1, l, r);
    } else
    {
        if (l > mid)
            return query(rt << 1 | 1, l, r);
        else
            return min(query(rt << 1, l, r), query(rt << 1 | 1, l, r));
    }
}


void uprange(int x, int y, int num)
{
    while (top[x] != top[y])
    {
        if (dep[top[x]] < dep[top[y]])
        {
            swap(x, y);
        }
        update(1, id[top[x]], id[x], num);
        x = fa[top[x]];
    }
    if (dep[x] > dep[y])
        swap(x, y);
    update(1, id[x] + 1, id[y], num);
}

int qrange(int x, int y)
{
    int ans = inf;
    if (top[x] != top[y])
    {
        if (dep[top[x]] < dep[top[y]])
        {
            swap(x, y);
        }
        ans = min(ans, query(1, id[top[x]], id[x]));
        x = fa[top[x]];
    } else
    {
        if (dep[x] > dep[y])
            swap(x, y);
        ans = min(ans, query(1, id[x] + 1, id[y]));
    }
    return ans;
}
struct edge
{
    int f, t, val, id;
    bool operator < (const edge& bb) const
    {
        return val < bb.val;
    }
} info[maxn];
int far[maxn];
int dsu_sz[maxn];
void dsu_init(int n)
{
    repd(i, 0, n)
    {
        far[i] = i;
        dsu_sz[i] = 1;
    }
}
int findpar(int x)
{
    if (x == far[x])
    {
        return x;
    } else
    {
        return far[x] = findpar(far[x]);
    }
}
void mg(int x, int y)
{
    x = findpar(x);
    y = findpar(y);
    if (x == y)
        return;
    if (dsu_sz[x] > dsu_sz[y])
    {
        dsu_sz[x] += dsu_sz[y];
        far[y] = x;
    } else
    {
        dsu_sz[y] += dsu_sz[x];
        far[x] = y;
    }
}
bool isok = 0;
int ans = 0;
bool vis[maxn];
void solve()
{
    sort(info + 1, info + 1 + m);
    dsu_init(n);
    int cnt = 0;
    repd(i, 1, m)
    {
        int x = findpar(info[i].f);
        int y = findpar(info[i].t);
        int w = info[i].val;
        if (x != y)
        {
            ans += w;
            cnt++;
            son[info[i].f].push_back(mp(info[i].t, w));
            son[info[i].t].push_back(mp(info[i].f, w));
            mg(x, y);
        } else
        {
            vis[info[i].id] = 1;
        }
    }
    isok = cnt == n - 1;
}
edge ed[maxn];
int main()
{
    cin >> n >> m;
    root = 1;
    int u, v, val;
    repd(i, 1, m)
    {
        cin >> u >> v >> val;
        info[i] = (edge) {u, v, val, i};
        ed[i] = info[i];
    }
    solve();
    if (isok == 0)
    {
        int q;
        cin >> q;
        while (q--)
        {
            int i;
            cin >> i;
            cout << "Not connected" << endl;
        }
        return 0;
    }
    dfs1(root, 0, 1);
    dfs2(root, root);
    build(1, 1, n);
    repd(i, 1, m)
    {
        if (!vis[i])
            continue;
        auto e = ed[i];
        uprange(e.t, e.f, e.val);
    }
    int q, op, x, y;
    cin >> q;
    while (q--)
    {
        cin >> op;
        x = ed[op].f;
        y = ed[op].t;
        if (vis[op])
        {
            cout << ans << endl;
            continue;
        }
        int res = qrange(x, y);
        if (res == inf)
        {
            cout << "Not connected" << endl;
        } else
        {
            cout << ans - ed[op].val + res << endl;
        }
    }

    return 0;
}

posted @ 2020-09-04 21:55  茄子Min  阅读(243)  评论(0编辑  收藏  举报