(Bzoj1977)次小生成树

试题描述
小C最近学了很多最小生成树的算法,Prim算法、 Kurskal算法、消圈算法等,正当小C
得意之时,小P又来泼小C冷水了。小P说,让小C求出一个无向图的次小生成树,而且这次
生成树还得是严格次小的,也就是说:如果最小生成树选择的边集是EM,严格次小生成树选的
边集是Es,那么需要满足:( value(e)表示边e的权值)
∑ value(e)<∑ value(e)
这下小C蒙了,他找到了你,希望你帮他解决这个问题
输入
第一行包含两个整数N和M,表示无向图的点数与边数
接下来M行,每行3个数x,y,z,表示点x和点y之间有一条边,边的权值为2
输出
包含一行,仅一个数,表示严格次小生成树的边权和。(数据保证存在严格次小生成树)
输入示例
5 6
1 2 1
1 3 2
2 4 3
3 5 4
3 4 3
4 5 6
输出示例
11

按理来讲是一道挺菜的题

然而我们学校的OJ卡常比较严重

交了五十多变也没过,但是在Loj上可以

权且当做看个思想吧

我们知道既然是生成树,那么肯定加一条边后会形成环

然后我们再删掉环上与加进去的边权值不同的最长边,就得到了次小生成树

然后再所有的情况中选出一个最小的

其中因为数据范围的缘故,需要一个LCA来维护

下面给出代码:

#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<string>
#include<cmath>
using namespace std;
inline long long rd()
{
    long long x=0,f=1;
    char ch=getchar();
    for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-1;
    for(;isdigit(ch);ch=getchar()) x=(x<<3)+(x<<1)+ch-'0';
    return x*f;
}
inline void write(long long x)
{
    if(x<0) putchar('-'),x=-x;
    if(x>9) write(x/10);
    putchar(x%10+'0');
    return ;
}
int n,m;
struct node{
    int u,v;
    long long w;
}s[600006];
bool cmp(const node x,const node y){
    return x.w<y.w;
}
int f[300006]; 
int head[300006],to[600006],nxt[600006],dis[600006];
int fa[300006][30];
int dep[300006];
long long maxn[300006][30];
long long maxm[300006][30];
int vis[300006];
int total=0;
void add(int x,int y,long long z){
    total++;
    to[total]=y;
    dis[total]=z;
    nxt[total]=head[x];
    head[x]=total;
    return ;
}
int getf(int v){
    if(v==f[v]) return v;
    return f[v]=getf(f[v]);
}
void dfs(int x,int la){
    dep[x]=dep[la]+1;
    for(register int e=head[x];e;e=nxt[e]){
        if(to[e]!=la){
            fa[to[e]][0]=x;
            maxn[to[e]][0]=dis[e];
            maxm[to[e]][0]=-0x7fffffff;
            dfs(to[e],x);
        }
    }
    return ;
}
int LCA(int x,int y,long long w){
    if(dep[x]<dep[y]) swap(x,y);
    long long maxs=-0x7fffffff;
    int de=dep[x]-dep[y];
    for(register int i=16;i>=0;i--){
        if((1<<i)&de){
            if(maxn[x][i]==w) maxs=maxs>maxm[x][i]?maxs:maxm[x][i];
            else maxs=maxs>maxn[x][i]?maxs:maxn[x][i];
            x=fa[x][i];
        }
    }
    if(x==y){
        if(maxs==-0x7fffffff) return 0;
        return w-maxs;
    }
    for(register int i=16;i>=0;i--){
        if(fa[x][i]!=fa[y][i]){
            if(maxn[x][i]==w) maxs=maxs>maxm[x][i]?maxs:maxm[x][i];
            else maxs=maxs>maxn[x][i]?maxs:maxn[x][i];
            if(maxn[y][i]==w) maxs=maxs>maxm[y][i]?maxs:maxm[y][i];
            else maxs=maxs>maxn[y][i]?maxs:maxn[y][i];
            x=fa[x][i];
            y=fa[y][i];
        }
    }
    if(maxn[x][0]==w) maxs=maxs>maxm[x][0]?maxs:maxm[x][0];
    else maxs=maxs>maxn[x][0]?maxs:maxn[x][0];
    if(maxn[y][0]==w) maxs=maxs>maxm[y][0]?maxs:maxm[y][0];
    else maxs=maxs>maxn[y][0]?maxs:maxn[y][0];
    if(maxs==-0x7fffffff) return 0;
    return w-maxs;
}
int main()
{
    n=rd();
    m=rd();
    for(register int i=1;i<=m;i++){
        s[i].u=rd();
        s[i].v=rd();
        s[i].w=rd();
    }
    for(register int i=1;i<=n;i++) f[i]=i;
    sort(s+1,s+m+1,cmp);
    int cnt=0;
    long long sum=0;
    for(register int i=1;i<=m;i++){
        int x=getf(s[i].v),y=getf(s[i].u);
        if(x!=y){
            vis[i]=1;
            f[x]=y;
            cnt++;
            sum+=s[i].w;
            add(s[i].v,s[i].u,s[i].w);
            add(s[i].u,s[i].v,s[i].w);
        }
        if(cnt==n-1) break;
    }
    dfs(1,0);
    for(register int j=1;j<=16;j++){
        for(int e=1;e<=n;e++){
            fa[e][j]=fa[fa[e][j-1]][j-1];
            maxn[e][j]=maxn[e][j-1]>maxn[fa[e][j-1]][j-1]?maxn[e][j-1]:maxn[fa[e][j-1]][j-1];
            maxm[e][j]=maxm[e][j-1]>maxm[fa[e][j-1]][j-1]?maxm[e][j-1]:maxm[fa[e][j-1]][j-1];
            if(maxn[e][j-1]>maxn[fa[e][j-1]][j-1]) maxm[e][j]=maxm[e][j]>maxn[fa[e][j-1]][j-1]?maxm[e][j]:maxn[fa[e][j-1]][j-1];
            if(maxn[e][j-1]<maxn[fa[e][j-1]][j-1]) maxm[e][j]=maxm[e][j]>maxn[e][j-1]?maxm[e][j]:maxn[e][j-1];
        }
    }
    long long ans=0x7fffffff;
    for(register int i=1;i<=m;i++){
        if(vis[i]) continue;
        long long num=LCA(s[i].u,s[i].v,s[i].w);
        if(!num) continue;
        ans=min(ans,num);
    }
    write(ans+sum);
    return 0;
}

 

posted @ 2018-09-28 21:05  Bruce--Wang  阅读(158)  评论(0编辑  收藏  举报