线段树优化建图 (CF786B、SNOI2017炸弹)
先来看板子题: CF786B
可以发现,如果对着区间内的每一个点都建一条边,然后跑最短路,我们无论是在空间还是时间复杂度上都是过不去的。因此,我们请出老朋友线段树。
参考上图。修建两棵线段树。其中一棵从父亲向左右儿子连边,若为有权图则边权为 \(0\), 以此保证每一个区间可以到达区间内部的每一个点。
第二课由儿子向父亲连边,以此保证每个点可以到达包含他的区间。
然后,单点和其对应的单点连边,以此保证互相可以获取对方信息。
之后考虑操作。
令下方线段树用于发出连边,可以参考图中黄色线段。这样便能保证可以到达对应区间内的所有点了。
(注意,在初始时,\(i\) 与 \(i + K\)只能互相到达,是不能到别的区间内部的,观察图即可发现该要求得到保证。)
值得注意的是,此处的单点应使用线段树中的节点编号,以避免混淆。
那么上板子题代码:
#include <bits/stdc++.h>
using namespace std;
#define N 3000010
#define ll long long
const int K = 500010;
template <class T>
inline void read(T& a){
T x = 0, s = 1;
char c = getchar();
while(!isdigit(c)){ if(c == '-') s = -1; c = getchar(); }
while(isdigit(c)){ x = x * 10 + (c ^ '0'); c = getchar(); }
a = x * s;
return ;
}
int n, Q, s;
int a[N]; // 每个点在线段树中的编号
struct node{
int u, v, w, next;
} t[N];
int head[N];
int bian = 0;
inline void addedge(int u, int v, int w){
t[++bian] = (node){u, v, w, head[u]}, head[u] = bian;
return ;
}
struct Segment_tree{
#define lson (o<<1)
#define rson (o<<1|1)
void build(int o, int l, int r){
if(l == r){
a[l] = o;
return ;
}
addedge(o, lson, 0); addedge(o, rson, 0);
addedge(lson + K, o + K, 0); addedge(rson + K, o + K, 0);
int mid = l + r >>1;
build(lson, l, mid);
build(rson, mid + 1, r);
return ;
}
void update(int o, int l, int r, int in, int end, int k, int w, int opt){
if(l > end || r < in) return ;
if(l >= in && r <= end){
if(opt == 2){ // v -> [l,r]
addedge(a[k] + K, o, w);
}
else addedge(o + K, a[k], w);
return ;
}
int mid = l + r >> 1;
update(lson, l, mid, in, end, k, w, opt);
update(rson, mid + 1, r, in, end, k, w, opt);
return ;
}
} tree;
struct point{
ll dis, id;
bool operator < (const point &a) const{
return dis > a.dis;
}
} ;
priority_queue <point> q;
ll dis[N];
bool vis[N];
void dij(int s){
memset(dis, 0x3f3f3f3f3f3f3f, sizeof(dis));
dis[s] = 0;
q.push((point){0, s});
while(!q.empty()){
int u = q.top().id; q.pop();
if(!vis[u]){
vis[u] = 1;
for(int i = head[u]; i; i = t[i].next){
int v = t[i].v;
if(dis[v] > dis[u] + t[i].w){
dis[v] = dis[u] + t[i].w;
if(!vis[v])q.push((point){dis[v], v});
}
}
}
}
return ;
}
signed main(){
// freopen("hh.txt", "r", stdin);
read(n), read(Q), read(s);
tree.build(1, 1, n);
for(int i = 1; i <= n; i++){
addedge(a[i], a[i] + K, 0);
addedge(a[i] + K, a[i], 0);
}
while(Q--){
ll opt, x, l, r, w;
read(opt);
if(opt == 1){
read(x), read(l), read(w);
addedge(a[x] + K, a[l], w);
}
else{
read(x), read(l), read(r), read(w);
tree.update(1, 1, n, l, r, x, w, opt);
}
}
dij(a[s] + K);
for(int i = 1; i <= n; i++){
printf("%lld ", dis[a[i]] <= 0x3f3f3f3f3f3f3f ? dis[a[i]] : -1);
}
return 0;
}
对于另外一道例题: [P5025 SNOI2017炸弹] (https://www.luogu.com.cn/problem/P5025)
首先使用 \(lower_bound\) 求出每一个炸弹可以对应的单层引爆区间,然后用单向边指向该区间。按理来说直接跑图即可,但考虑到炸弹之间可能可以互相引爆,因此先缩点去环。
还有一点,每个炸弹引爆的点是一条条线段,其交集可能被重复计算,因此要统计每一个节点所包含的区间(左右端点 \(l,r\))来计算。(说明:只有有交集线段可以合并,未交集线段不会被合并到一起,因此该方法正确。)
注意,该题其实不需要第二课线段树,因为只有单点向区间加边这一个方向。题解中的两棵线段树纯属为了练习(折磨自己)
#include <bits/stdc++.h>
using namespace std;
#define N 8000010
#define ll long long
const ll K = 2e6 + 1;
const ll mod = 1e9 + 7;
template <class T>
inline void read(T& a){
T x = 0, s = 1;
char c = getchar();
while(!isdigit(c)){ if(c == '-') s = -1; c = getchar(); }
while(isdigit(c)){ x = x * 10 + (c ^ '0'); c = getchar(); }
a = x * s;
return ;
}
struct node{
int u, v, next;
} t[N], t1[N];
int head[N];
int bian = 0;
inline void addedge(int u, int v){
t[++bian] = (node){u, v, head[u]}, head[u] = bian;
return ;
}
int bian1 = 0;
int head1[N];
inline void addedge1(int u, int v){
t1[++bian1] = (node){u, v, head1[u]}, head1[u] = bian1;
return ;
}
ll n, x[N];
ll r[N];
int a[N];
struct seg{
int l, r;
seg(){
this->l = 1e18;
this->r = 0;
return ;
}
} se[N]; // 每个节点对应的区间
seg se2[N]; // 缩点后对应的区间
ll tot = 0; // 节点总数
struct Segment_tree{
#define lson (o<<1)
#define rson (o<<1|1)
void build(int o, int l, int r){
se[o].l = l, se[o].r = r;
tot = max(tot, (ll)(o + K));
if(l == r){
// tot = max(tot, (ll)(o));
a[l] = o;
return ;
}
int mid = l + r >> 1;
build(lson, l, mid); build(rson, mid + 1, r);
addedge(o, lson); addedge(o, rson);
addedge(lson + K, o + K); addedge(rson + K, o + K);
return ;
}
void update(int o, int l, int r, int in, int end, int k){
if(l > end || r < in) return ;
if(l >= in && r <= end){ // 点向区间加边
addedge(a[k] + K, o);
return ;
}
int mid = l + r >> 1;
update(lson, l, mid, in, end, k);
update(rson, mid + 1, r, in, end, k);
return ;
}
} tree;
struct point{
ll x, r;
int id;
bool operator < (const point &a) const{
return x < a.x;
}
} p[N];
int dfn[N], low[N];
int id = 0;
int stac[N], top = 0;
int scc[N], cnt = 0;
bool vis[N];
void tarjan(int u){
low[u] = dfn[u] = ++id;
vis[u] = 1;
stac[++top] = u;
for(int i = head[u]; i; i = t[i].next){
int v = t[i].v;
if(!dfn[v]){
tarjan(v);
low[u] = min(low[u], low[v]);
} else if(vis[v]) low[u] = min(low[u], dfn[v]);
}
if(low[u] == dfn[u]){
int cur;
cnt++;
do{
cur = stac[top--];
scc[cur] = cnt;
se2[cnt].l = min(se[cur].l, se2[cnt].l);
se2[cnt].r = max(se[cur].r, se2[cnt].r);
vis[cur] = 0;
} while(cur != u);
}
return ;
}
void dfs(int u){
vis[u] = 1;
for(int i = head1[u]; i; i = t1[i].next){
int v = t1[i].v;
if(!vis[v]) dfs(v);
se2[u].l = min(se2[u].l, se2[v].l);
se2[u].r = max(se2[u].r, se2[v].r);
}
return ;
}
signed main(){
// freopen("hh.txt", "r", stdin);
read(n);
tree.build(1, 1, n);
for(int i = 1; i <= n; i++){
addedge(a[i], a[i] + K);
addedge(a[i] + K, a[i]);
}
for(int i = 1; i <= n; i++){
read(p[i].x); read(p[i].r);
p[i].id = i;
}
sort(p + 1, p + n + 1);
for(int i = 1; i <= n; i++){
ll L, R;
ll lnum = p[i].x - p[i].r;
ll rnum = p[i].x + p[i].r;
L = lower_bound(p + 1, p + n + 1, (point){lnum, 0, 0}) - p;
R = lower_bound(p + 1, p + n + 1, (point){rnum + 1, 0, 0}) - p - 1;
tree.update(1, 1, n, L, R, i);
}
for(int i = 1; i <= tot; i++)
if(!dfn[i]) tarjan(i);
for(int u = 1; u <= tot; u++){
for(int i = head[u]; i; i = t[i].next){
int v = t[i].v;
if(scc[u] != scc[v]){
addedge1(scc[u], scc[v]);
}
}
}
// memset(vis, 0, sizeof(vis));
for(int i = 1; i <= cnt; i++){
dfs(i);
}
ll ans = 0;
for(int i = 1; i <= n; i++)
ans = (ans + (ll)i * (se2[scc[a[i] + K]].r - se2[scc[a[i] + K]].l + 1) % mod) % mod;
cout << ans << endl;
return 0;
}
时隔一年,换一种更优秀的写法。CSP认证2022.12E。
#include <bits/stdc++.h>
using namespace std;
#define N 400010
#define Log 20
#define ll long long
#define y1 abdbsdhuhdwiuoh
#define int long long
template <class T>
inline T read(T& a){
T x = 0, s = 1;
char c = getchar();
while(!isdigit(c)){ if(c == '-') s = -1; c = getchar(); }
while(isdigit(c)){ x = x * 10 + (c ^ '0'); c= getchar(); }
a = x * s;
return a;
}
struct node{
int u, v;
ll w;
int next;
} t[N * Log], t1[N * Log];
int head[N * Log];
int bian = 0;
inline void addedge(int u, int v, ll w){
t[++bian] = (node){u, v, w, head[u]}, head[u] = bian;
return ;
}
int head1[N * Log];
int bian1 = 0;
inline void addedge1(int u, int v, ll w){
t1[++bian1] = (node){u, v, w, head1[u]}, head1[u] = bian1;
return ;
}
int rc[N * Log], lc[N * Log];
int rootin = 0, rootout = 0;
int id = 0;
int n, m;
#define lson rc[o]
#define rson lc[o]
void buildin(int &o, int l, int r){
if(l == r){
o = l;
return ;
}
if(!o) o = ++id;
int mid = l + r >> 1;
buildin(lson, l, mid);
buildin(rson, mid + 1, r);
addedge(o, lson, 0);
addedge(o, rson, 0);
addedge1(lson, o, 0);
addedge1(rson, o, 0);
return ;
}
void buildout(int &o, int l, int r){
if(l == r){
o = l;
return ;
}
if(!o) o = ++id;
int mid = l + r >> 1;
buildout(lson, l, mid);
buildout(rson, mid + 1, r);
addedge(lson, o, 0);
addedge(rson ,o, 0);
addedge1(o, lson, 0);
addedge1(o, rson, 0);
return ;
}
void update(int o, int l, int r, int in, int end, int val, ll k, int type){
if(l > end || r < in) return ;
if(l >= in && r <= end){
if(type == 1){
addedge(o, val, k);
addedge1(val, o, k);
}
else{
addedge(val, o, k);
addedge1(o, val, k);
}
return ;
}
int mid = l + r >> 1;
update(lson, l, mid, in, end, val, k, type);
update(rson, mid + 1, r, in, end, val, k, type);
return ;
}
ll dis[N * Log * 2];
struct Point{
ll dis;
int id;
bool operator < (const Point &a) const{
return dis > a.dis;
}
} ;
bool vis[N];
const ll mod = 1e9 + 7;
ll qpow(ll a, ll b){
ll sum = 1;
while(b){
if(b & 1) sum = (sum * a) % mod;
b >>= 1ll;
a = (a * a) % mod;
}
return sum;
}
void dij(int s){
memset(vis, 0, sizeof(vis));
for(int i = 1; i <= id; i++)
dis[i] = 1e16;
// cout << dis[1] << endl;
priority_queue <Point> q;
dis[s] = 0;
q.push((Point){0, s});
while(!q.empty()){
int now = q.top().id; q.pop();
if(!vis[now]){
vis[now] = 1;
for(int i = head[now]; i; i = t[i].next){
int v = t[i].v;
if(dis[v] > dis[now] + t[i].w){
// printf("%lld\n", dis[v]);
// printf("%lld\n", t[i].w);
dis[v] = dis[now] + t[i].w;
// cout << dis[v] << endl;
if(!vis[v]) q.push((Point){dis[v], v});
}
}
}
}
return ;
}
ll dis1[N];
void dij1(int s){
memset(vis, 0, sizeof(vis));
for(int i = 1; i <= id; i++)
dis1[i] = 1e16;
priority_queue <Point> q;
dis1[s] = 0;
q.push((Point){0, s});
while(!q.empty()){
int now = q.top().id; q.pop();
// printf("now: %lld\n", now);
if(!vis[now]){
vis[now] = 1;
for(int i = head1[now]; i; i = t1[i].next){
int v = t1[i].v;
if(dis1[v] > dis1[now] + t1[i].w){
dis1[v] = dis1[now] + t1[i].w;
if(!vis[v]) q.push((Point){dis1[v], v});
}
}
}
}
return ;
}
signed main(){
// freopen("hh.txt", "r", stdin);
// freopen("out.txt", "w", stdout);
read(n), read(m);
id = n;
buildin(rootin, 1, n);
buildout(rootout, 1, n);
for(int i = 1; i <= m; i++){
id++;
int l1, r1, l2, r2; ll a, b;
cin >> l1 >> r1 >> l2 >> r2 >> a >> b;
update(rootout, 1, n, l1, r1, id, a * qpow(2, b), 1);
update(rootin, 1, n, l2, r2, id, a * qpow(2, b), 2);
}
dij(1);
dij1(1);
for(int i = 2; i <= n; i++){
if(dis[i] >= 1e16 || dis1[i] >= 1e16) cout << -1 << " ";
else cout << ((dis[i] % mod + dis1[i] % mod) % mod * qpow(2, mod - 2)) % mod << " ";
}
return 0;
}