E-Explorer_2019牛客暑期多校训练营(第八场)
题意
n个点,m条边,u,v,l,r表示点u到点v有一条边,且只有编号为\([l,r]\)的人能通过,问从点1到点n有哪些编号的人能通过
题解
先对\(l,r\)离散化,用第七场找中位数那题同样的形式建树,每个叶子节点表示的都是一个区间,树上每个节点维护的是,包含这个区间的边有哪些,可以用vector存下来,接着是查询哪些叶子节点可以作为答案,像线段树build一样左右一直递归下去,每递归一次都把这个节点存的边用并查集合并,思考一下就能知道,按build的方式递归下去,到叶子节点x,一定能途径所有包含叶子节点x所表示的区间的边,且不会途径不包含x表示的区间的边,这样到叶子节点是就把包含该叶子节点区间的边全部合并,此时查询点1和点n是否在一起,若在一起就可以把该叶子节点区间计入答案,查询的时候每层递归结束都要撤销并查集合并操作,以免影响递归其他叶子节点的时候影响答案,用启发式并查集合并就能撤销了
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mx = 2e5+5;
vector <int> vv;
vector <int> edge[mx<<2];
int u[mx], v[mx], l[mx], r[mx], fa[mx], sz[mx], top = 0;
int n, m;
pair <int, int> st[mx];
int getid(int x) {
return lower_bound(vv.begin(), vv.end(), x) - vv.begin() + 1;
}
int find(int x) {
return fa[x] == x ? x : find(fa[x]);
}
void merge(int u, int v) {
int fau = find(u);
int fav = find(v);
if (fau == fav) return;
if (sz[fau] > sz[fav]) swap(fau, fav);
fa[fau] = fav;
st[top++] = {fau, sz[fau]};
if (sz[fau] == sz[fav]) {
st[top++] = {fav, sz[fav]};
sz[fav]++;
}
}
void update(int L, int R, int id, int l, int r, int rt) {
if (L <= l && r <= R) {
edge[rt].push_back(id);
return;
}
int mid = (l + r) / 2;
if (L <= mid) update(L, R, id, l, mid, rt<<1);
if (mid < R) update(L, R, id, mid+1, r, rt<<1|1);
}
void cancel(int pre) {
while (top > pre) {
fa[st[top-1].first] = st[top-1].first;
sz[st[top-1].first] = st[top-1].second;
top--;
}
}
void query(int l, int r, int rt, int &ans) {
int pre = top;
for (int i = 0; i < edge[rt].size(); i++) {
int id = edge[rt][i];
merge(u[id], v[id]);
}
if (l == r) {
if (find(1) == find(n)) ans += vv[r]-vv[l-1];
cancel(pre);
return;
}
int mid = (l + r) / 2;
query(l, mid, rt<<1, ans);
query(mid+1, r, rt<<1|1, ans);
cancel(pre);
}
int main() {
for (int i = 0; i < mx; i++) {
fa[i] = i;
sz[i] = 1;
}
scanf("%d%d", &n, &m);
for (int i = 1; i <= m; i++) {
scanf("%d%d%d%d", &u[i], &v[i], &l[i], &r[i]);
r[i]++;
vv.push_back(l[i]);
vv.push_back(r[i]);
}
sort(vv.begin(), vv.end());
vv.push_back(vv[vv.size()-1] + 1);
vv.erase(unique(vv.begin(), vv.end()), vv.end());
for (int i = 1; i <= m; i++) {
update(getid(l[i]), getid(r[i])-1, i, 1, vv.size()-1, 1);
}
int ans = 0;
query(1, vv.size()-1, 1, ans);
printf("%d\n", ans);
return 0;
}