【做题】CSA72G - MST and Rectangles——Borůvka&线段树
原文链接 https://www.cnblogs.com/cly-none/p/CSA72G.html
题意:有一个\(n \times n\)的矩阵\(A\),\(m\)次操作,每次在\(A\)上三角部分的一个子矩形中加上一个数。最后构造\(n\)个点的图\(G\),且对于所有\(i,j \ (i < j)\),边\((i,j)\)的边权为\(A_{i,j}\)。求图\(G\)的最小生成树的边权和。
\(n,m \leq 10^5\)
先把上三角矩阵补成邻接矩阵。这样每次操作就是加两个邻接矩阵的子矩形。
这种题目通常要对经典算法进行拓展。常用的最小生成树算法有Prim和Kruskal,然而在尝试之后我们发现,由于边权种类太多,Prim不行;同样Kruskal也难以提出排序后的边权。
但还有Borůvka。这个算法要求对每个联通块找到边权最小的邻边,还要合并联通块。后者用并查集能容易实现,现在仅考虑前者。
先想一个更简单的问题:对每个结点找到边权最小的邻边。这是简单的,我们只需要扫描邻接矩阵的每一行,这样每次矩形加都变成了两个区间加。用线段树维护最小值就好了。考虑原问题。这个最小值可能和这个点在同一个联通块内,因此,非常套路地,我们就再记录与最小值不在一个联通块内的次小值就可以了。
因为上述扫描线需要执行\(O(\log n)\)次,故复杂度为\(O(n \log^2 n)\)。
#include <bits/stdc++.h>
#define data DATA
using namespace std;
#define gc() getchar()
template <typename tp>
inline void read(tp& x) {
x = 0; char tmp; bool key = 0;
for (tmp = gc() ; !isdigit(tmp) ; tmp = gc())
key = (tmp == '-');
for ( ; isdigit(tmp) ; tmp = gc())
x = (x << 3) + (x << 1) + (tmp ^ '0');
if (key) x = -x;
}
typedef long long ll;
const int N = 100010;
const ll INF = 1ll << 60;
int n,m,uni[N],cnt;
ll ans;
int getfa(int pos) {
return pos == uni[pos] ? pos : uni[pos] = getfa(uni[pos]);
}
struct data {
int p,l,r,v;
bool operator < (const data& a) const {
return p < a.p;
}
} dat[N << 2];
typedef pair<ll,int> pli;
#define fi first
#define se second
struct node {
pli mn[2];
ll tag;
node() {
tag = 0;
mn[0] = mn[1] = pli(INF,0);
}
} t[N << 2];
void puttag(int x,ll v) {
t[x].mn[0].fi += v;
t[x].mn[1].fi += v;
t[x].tag += v;
}
void push_down(int x) {
puttag(x<<1,t[x].tag);
puttag(x<<1|1,t[x].tag);
t[x].tag = 0;
}
void push_up(node& x,node ls,node rs) {
x.mn[1].fi = INF;
if (ls.mn[0] < rs.mn[0]) {
x.mn[0] = ls.mn[0];
if (rs.mn[0].se != x.mn[0].se)
x.mn[1] = rs.mn[0];
} else {
x.mn[0] = rs.mn[0];
if (ls.mn[0].se != x.mn[0].se)
x.mn[1] = ls.mn[0];
}
if (ls.mn[1].se != x.mn[0].se)
x.mn[1] = min(x.mn[1], ls.mn[1]);
if (rs.mn[1].se != x.mn[0].se)
x.mn[1] = min(x.mn[1], rs.mn[1]);
}
void modify(int l,int r,ll v,int x=1,int lp=1,int rp=n) {
if (lp > r || l > rp) return;
if (lp >= l && rp <= r)
return (void) puttag(x,v);
push_down(x);
int mid = (lp + rp) >> 1;
modify(l,r,v,x<<1,lp,mid);
modify(l,r,v,x<<1|1,mid+1,rp);
push_up(t[x],t[x<<1],t[x<<1|1]);
}
void build(int x=1,int lp=1,int rp=n) {
t[x].tag = 0;
if (lp == rp) {
t[x].mn[0] = pli(0ll,uni[lp]);
t[x].mn[1] = pli(INF,0);
return;
}
int mid = (lp + rp) >> 1;
build(x<<1,lp,mid);
build(x<<1|1,mid+1,rp);
push_up(t[x], t[x<<1], t[x<<1|1]);
}
pli nex[N];
void solve() {
build();
for (int i = 1 ; i <= n ; ++ i)
nex[i] = pli(INF,0);
for (int i = 1, j = 1 ; i <= n ; ++ i) {
while (j <= cnt && dat[j].p <= i)
modify(dat[j].l, dat[j].r, dat[j].v), ++ j;
node tmp = t[1];
if (tmp.mn[0].se != uni[i])
nex[uni[i]] = min(nex[uni[i]], tmp.mn[0]);
else nex[uni[i]] = min(nex[uni[i]], tmp.mn[1]);
}
for (int i = 1, j ; i <= n ; ++ i) {
j = getfa(i);
if (nex[j].se == INF) continue;
if (getfa(nex[j].se) != j) {
ans += nex[j].fi;
uni[j] = getfa(nex[j].se);
}
}
}
bool check() {
for (int i = 1 ; i <= n ; ++ i)
uni[i] = getfa(i);
for (int i = 2 ; i <= n ; ++ i)
if (uni[i] != uni[i-1]) return 1;
return 0;
}
signed main() {
read(n), read(m);
for (int i = 1, a, b, c, d, e ; i <= m ; ++ i) {
read(a), read(b), read(c), read(d), read(e);
dat[++cnt] = (data) {a, c, d, e};
dat[++cnt] = (data) {b+1, c, d, -e};
dat[++cnt] = (data) {c, a, b, e};
dat[++cnt] = (data) {d+1, a, b, -e};
}
for (int i = 1 ; i <= n ; ++ i)
uni[i] = i;
sort(dat+1,dat+cnt+1);
while (check())
solve();
cout << ans << endl;
return 0;
}
小结:本题做法乍一看是几个套路的综合,但并不简单。还是要求清晰的思维,以及熟练掌握基础算法和技巧。