BZOJ1977: [BeiJing2010组队]次小生成树 Tree
Description
小 C 最近学了很多最小生成树的算法,Prim 算法、Kurskal 算法、消圈算法等等。 正当小 C 洋洋得意之时,小 P 又来泼小 C 冷水了。小 P 说,让小 C 求出一个无向图的次小生成树,而且这个次小生成树还得是严格次小的,也就是说: 如果最小生成树选择的边集是 EM,严格次小生成树选择的边集是 ES,那么需要满足:(value(e) 表示边 e的权值) 这下小 C 蒙了,他找到了你,希望你帮他解决这个问题。
Input
第一行包含两个整数N 和M,表示无向图的点数与边数。 接下来 M行,每行 3个数x y z 表示,点 x 和点y之间有一条边,边的权值为z。
Output
包含一行,仅一个数,表示严格次小生成树的边权和。(数据保证必定存在严格次小生成树)
Sample Input
5 6
1 2 1
1 3 2
2 4 3
3 5 4
3 4 3
4 5 6
Sample Output
11
HINT
数据中无向图无自环; 50% 的数据N≤2 000 M≤3 000; 80% 的数据N≤50 000 M≤100 000; 100% 的数据N≤100 000 M≤300 000 ,边权值非负且不超过 10^9 。
Solution
好恶心的一道题...
首先要知道有这么一个定理:
严格次小MST一定只有一条边和MST不同。
然后这题就可以做了。
先随便找出来一棵MST,然后把树建出来,称这\(n-1\)条边为“树边”,其他的边为“非树边”。
则如果把一条非树边\((x,y)\)接到树上,那么就会和树上x到y的路径产生一个环,为了保持树的形态,所以必须删掉树上x到y的路径上的一条边,而因为我萌要维护的是严格次小MST,所以需要知道树上x到y路径中的最大值和次大值(如果非树边\((x,y)\)的边权等于最大值,就必须换掉次大值,因为我萌要求的是严格次小MST)。
如何求树链的最大值和次大值?这个问题可以树上倍增解决。(当然也可以树剖,但是复杂度是两个log,而且更难写)。
设\(g[x,i,0/1]\)表示节点\(x\)向上\(2^i\)个祖先路径中的最大值和次大值。
则显然有\(g[x,i,0]=edge(i,fa[i]),g[x,i,1]=-∞\),\(g[x,i,0]=max(g[x,i-1,0],g[f[x,i-1],i-1,0])\)
对次大值分类讨论一下:
于是预处理完之后,分别处理每条非树边\((x,y)\),在求x和y的lca的同时类似预处理那样把路径上的最大值和次大值求出来(方法完全类似所以不想写了...,反正就是分个三类)
对每次替换得到的次小MST取个min,就是答案了。
代码贼长...
#include <bits/stdc++.h>
#define ll long long
#define il inline
const ll inf = 1e18;
namespace io {
#define in(a) a = read()
#define out(a) write(a)
#define outn(a) out(a), putchar('\n')
#define I_int ll
inline I_int read() {
I_int x = 0, f = 1;
char c = getchar();
while (c < '0' || c > '9') {
if (c == '-') f = -1;
c = getchar();
}
while (c >= '0' && c <= '9') {
x = x * 10 + c - '0';
c = getchar();
}
return x * f;
}
char F[200];
inline void write(I_int x) {
if (x == 0) return (void) (putchar('0'));
I_int tmp = x > 0 ? x : -x;
if (x < 0) putchar('-');
int cnt = 0;
while (tmp > 0) {
F[cnt++] = tmp % 10 + '0';
tmp /= 10;
}
while (cnt > 0) putchar(F[--cnt]);
}
#undef I_int
}
using namespace io;
using namespace std;
#define N 300010
int n = read(), m = read(), lim;
int cnt, head[N], fa[N], f[N][20], dep[N];
ll g[N][25][2];
struct Node {
int x, y, v, flag;
}a[N];
struct edge {
int to, nxt, v;
}e[N<<3];
void ins(int u, int v, int w) {
e[++cnt] = (edge) {v, head[u], w};
head[u] = cnt;
}
bool cmp(Node a, Node b) {
return a.v < b.v;
}
void dfs(int u) {
for(int i = head[u]; i; i = e[i].nxt) {
if(e[i].to == f[u][0]) continue; int v = e[i].to;
f[v][0] = u; g[v][0][0] = 1ll*e[i].v, g[v][0][1] = -inf;
dep[v] = dep[u] + 1;
for(int j = 1; j <= lim; ++j) {
f[v][j] = f[f[v][j-1]][j-1];
g[v][j][0] = max(g[v][j-1][0], g[f[v][j-1]][j-1][0]);
if(g[v][j-1][0] == g[f[v][j-1]][j-1][0]) g[v][j][1] = max(g[v][j-1][1], g[f[v][j-1]][j-1][1]);
else if(g[v][j-1][0] > g[f[v][j-1]][j-1][0]) g[v][j][1] = max(g[v][j-1][1], g[f[v][j-1]][j-1][0]);
else g[v][j][1] = max(g[v][j-1][0], g[f[v][j-1]][j-1][1]);
}
dfs(v);
}
}
void lca(int x, int y, ll &t1, ll &t2) {
t1 = -inf; t2 = -inf;
if(dep[x] < dep[y]) swap(x, y);
for(int i = lim; i >= 0; --i) {
if(dep[f[x][i]] >= dep[y]) {
if(t1 == g[x][i][0]) t2 = max(t2, g[x][i][1]);
else if(t1 < g[x][i][0]) t2 = max(t1, g[x][i][1]), t1 = g[x][i][0];
else t2 = max(t2, g[x][i][0]);
x = f[x][i];
}
}
if(x == y) return;
for(int i = lim; i >= 0; --i) {
if(f[x][i] != f[y][i]) {
if(t1 == g[x][i][0]) t2 = max(t2, g[x][i][1]);
else if(t1 < g[x][i][0]) t2 = max(t1, g[x][i][1]), t1 = g[x][i][0];
else t2 = max(t2, g[x][i][0]);
x = f[x][i];
if(t1 == g[y][i][0]) t2 = max(t2, g[y][i][1]);
else if(t1 < g[y][i][0]) t2 = max(t1, g[y][i][1]), t1 = g[y][i][0];
else t2 = max(t2, g[y][i][0]);
y = f[y][i];
}
}
if(t1 == g[x][0][0]) t2 = max(t2, g[x][0][1]);
else if(t1 < g[x][0][0]) t2 = max(t1, g[x][0][1]), t1 = g[x][0][0];
else if(t2 < g[x][0][0]) t2 = g[x][0][0];
if(t1 == g[y][0][0]) t2 = max(t2, g[y][0][1]);
else if(t1 < g[y][0][0]) t2 = max(t1, g[y][0][1]), t1 = g[y][0][0];
else if(t2 < g[y][0][0]) t2 = g[y][0][0];
}
int find(int x) {
if(fa[x] == x) return x;
return fa[x] = find(fa[x]);
}
int main() {
for(int i = 1; i <= n; ++i) fa[i] = i;
for(int i = 1; i <= m; ++i) {
int x = read(), y = read(), v = read();
a[i] = (Node){x, y, v, 0};
}
sort(a+1,a+m+1,cmp); ll sum = 0;
for(int i = 1, tot = 0; tot < n - 1 && i <= m; ++i) {
int x = find(a[i].x), y = find(a[i].y);
if(x != y) {
fa[y] = x;
sum += 1ll*a[i].v;
ins(a[i].x, a[i].y, a[i].v);
ins(a[i].y, a[i].x, a[i].v);
a[i].flag = 1;
++tot;
}
}
lim = (int)(log(n) / log(2)) + 1;
for(int i = 1; i <= lim; ++i) g[1][i][0] = g[1][i][1] = -inf;
dep[1] = 1; dfs(1);
ll ans = inf;
for(int i = 1; i <= m; ++i) {
if(a[i].flag) continue;
ll mx = 0, se_mx = 0;
lca(a[i].x, a[i].y, mx, se_mx);
if(a[i].v == mx) ans = min(ans, sum + (ll)a[i].v - se_mx);
else ans = min(ans, sum + (ll)a[i].v - mx);
}
outn(ans);
}