线段树合并[学习笔记]
前置知识: 动态开点线段树/主席树
线段树合并,跟名字一样,就是合并两颗线段树的信息…
合并两颗线段树的方法是 递归下去一直合并两个如果俩节点都有…就新建一个维护当前两个的信息然后取而代之
如果只有一个有 那就返回那一个就够了…
这个blog内借鉴了一些洛谷的题解…因为作者实在菜的可怜 (懒)
先放个板子题吧
求子树几个比他权值大的
inline int Merge(int u , int v) {
if(!u || ! v) return u | v ;
int t = ++ cnt ;
sum[t] = sum[u] + sum[v] ;
ls[t] = Merge(ls[u] , ls[v]) ;
rs[t] = Merge(rs[u] , rs[v]) ;
return t ;
}
就像是这样…
然后每次递归到最底层 下属信息合并…然后查询 最后输出,没了。
// luogu-judger-enable-o2
//Isaunoya
#include<bits/stdc++.h>
using namespace std ;
inline int read() { register int x = 0 ; register int f = 1 ; register char c = getchar() ;
for( ; ! isdigit(c) ; c = getchar()) if(c == '-') f = -1 ;
for( ; isdigit(c) ; c = getchar()) x = (x << 1) + (x << 3) + (c & 15) ;
return x * f ;
} int st[105] ;
template < typename T > inline void write(T x , char c = '\n') { int tp = 0 ;
if(x == 0) return (void) puts("0") ;
if(x < 0) putchar('-') , x = -x ;
for( ; x ; x /= 10) st[++ tp] = x % 10 ;
for( ; tp ; tp --) putchar(st[tp] + '0') ;
putchar(c) ;
}
//#define Online_Judge
const int N = 1e5 + 10 ;
int val[N] , b[N] , tot = 0 ;
int fa[N] , head[N] , nxt[N] ;
int ls[N << 5] , rs[N << 5] , rt[N << 5] , cnt = 0 ;
int sum[N << 5] ;
int n , ans[N] ;
inline int get(int x) {
return lower_bound(b + 1 , b + tot + 1 , x) - b - 1 ;
}
inline void Modify(int & u , int l , int r , int pos) {
if(! u) u = ++ cnt ;
sum[u] ++ ;
if(l == r) return ;
int mid = l + r >> 1 ;
if(mid >= pos) Modify(ls[u] , l , mid , pos) ;
else Modify(rs[u] , mid + 1 , r , pos) ;
}
inline int query(int u , int l ,int r , int x) {
if(! u) return 0 ;
if(l >= x) return sum[u] ;
int mid =l+r>>1;
if(mid>=x)return query(ls[u],l,mid,x) + query(rs[u],mid+1,r,x) ;
return query(rs[u] , mid + 1 , r , x) ;
}
inline int Merge(int u , int v) {
if(!u || ! v) return u | v ;
int t = ++ cnt ;
sum[t] = sum[u] + sum[v] ;
ls[t] = Merge(ls[u] , ls[v]) ;
rs[t] = Merge(rs[u] , rs[v]) ;
return t ;
}
inline void Dfs(int u) {
for(register int i = head[u] ; i ; i = nxt[i]) {
Dfs(i) ;
rt[u] = Merge(rt[u] , rt[i]) ;
}
ans[u] = query(rt[u] , 1 , tot , val[u] + 1) ;
Modify(rt[u] , 1 , tot, val[u]) ;
}
signed main() {
#ifdef Online_Judge
freopen("testdata.in" , "r" , stdin) ;
freopen("testdata2.out" , "w" , stdout) ;
#endif
n = read() ;
for(register int i = 1 ; i <= n ; i ++) b[i] = val[i] = read() ;
for(register int i = 2 ; i <= n ; i ++) {
int u ;
u = fa[i] = read() ;
nxt[i] = head[fa[i]] ;
head[fa[i]] = i ;
}
sort(b + 1 , b + n + 1) ;
tot = unique(b + 1 , b + n + 1) - b - 1 ;
for(register int i = 1 ; i <= n ; i ++) val[i] = get(val[i]) ;
Dfs(1) ;
for(register int i = 1 ; i <= n ; i ++)
write(ans[i]) ;
return 0 ;
}
放几个题吧…
几乎相同的套路…
首先,我们可以先理解一下题面:固定一个a,找到一个b,c就是a与b的公共子树中的某个点。
那么,我们显然可以把这个b分成两类,第一种是在a上面的,第二种在a下面的。
对于b在a上面的情况,显然,c一定是a的子树中的某个点,答案即\(min(k,dep_a)*size_a\)
这颗线段树中的\([dep_x+1,dep_x+K]\)区间的sum就是这个\(x\)点的答案
// luogu-judger-enable-o2
#include <bits/stdc++.h>
using namespace std ;
const int N = 3e5 + 10 ;
const int xN = N * 21 ;
inline int read(){
register int x = 0 ;
register int f = 1 ;
register char c = getchar() ;
for( ; ! isdigit(c) ; c = getchar())
if(c == '-') f = -1 ;
for( ; isdigit(c) ; c = getchar())
x = (x << 1) + (x << 3) + (c & 15) ;
return x * f ;
}
#define int long long
int n ;
struct node {
int v ;
int nxt ;
};
node e[N << 1] ;
int head[N] ;
int tot = 0 ;
int sum[xN] ;
int rt[xN] ;
int ls[xN] ;
int rs[xN] ;
int d[N] ;
inline void Ins(int & o , int l ,int r , int p , int v) {
if(! o ) o = ++ tot ;
sum[o] += v ;
if(l == r) return ;
int mid = l + r >> 1 ;
if(p <= mid) Ins(ls[o] , l , mid ,p , v) ;
else Ins(rs[o] , mid + 1 , r , p , v ) ;
}
inline long long Query(int a , int b , int l, int r , int k) {
if(! k) return 0 ;
if(a <= l && r <= b) return sum[k] ;
int mid = l + r >> 1 ;
long long ans = 0 ;
if(a <= mid) ans += Query(a , b , l , mid , ls[k]) ;
if(b > mid) ans += Query(a , b , mid + 1 , r , rs[k]) ;
return ans ;
}
int size[N] ;
inline int Merge(int x , int y , int l ,int r) {
if(! x || ! y) return x | y ;
int mid = l + r >> 1 ;
int t = ++ tot ;
sum[t] = sum[x] + sum[y] ;
ls[t] = Merge(ls[x] , ls[y] , l , mid) ;
rs[t] = Merge(rs[x] , rs[y] , mid + 1 , r) ;
return t ;
}
int cnt = 0 ;
inline void Add(int u , int v) {
e[++ cnt].v = v ;
e[cnt].nxt = head[u] ;
head[u] = cnt ;
return ;
}
inline void Dfs(int u , int fa) {
size[u] = 1 ; d[u] = d[fa] + 1 ;
for(register int i = head[u] ; i ; i = e[i].nxt) {
int v = e[i].v ;
if(v == fa) continue ;
Dfs(v , u) ;
size[u] += size[v] ;
}
Ins(rt[u] , 1 , n , d[u] , size[u] - 1) ;
if(fa) rt[fa] = Merge(rt[fa] , rt[u] , 1 , n) ;
}
signed main() {
n = read() ; int q = read() ;
for(register int i = 1 ; i <= n - 1 ; i ++) {
int u =read() ,v = read() ;
Add(u , v) ;
Add(v , u) ;
}
Dfs(1 , 0) ;
for(register int i = 1 ; i <= q ; i ++) {
int x = read() , y = read() ;
long long ans = Query(d[x] + 1 , d[x] + y , 1 , n , rt[x]) + 1LL * (size[x] - 1) * min(d[x] - 1 , y) ;
printf("%lld\n" , ans) ;
}
return 0 ;
}
[POI2011]ROT-Tree Rotations
poi!
给一棵n个叶子的二叉树,可以交换每个点的左右子树,要求前序遍历叶子的逆序对最少。
\(1\leq n \leq 100000\)
这道题主要就是权值线段树合并的一个过程。我们对每个叶子结点开一个权值线段树,然后逐步合并。
考虑到一件事情:如果在原树有一个根节点x,和其左儿子\(ls\),右儿子\(rs\)。我们要合并的是\(ls\)的权值线段树和\(rs\)的权值线段树,得到\(x\)的所有叶节点的权值线段树。
发现交换\(ls\)和\(rs\)并不会对原树更上层之间的逆序对产生影响,于是我们只需要每次合并都让逆序对最少。
于是我们的问题转化为了给定两个权值线段树,问把它们哪个放在左边可以使逆序对个数最小,为多少。
考虑我们合并到一个节点,其权值范围为\([l,r]\),中点为\(mid\)。这个时候我们有两棵树,我们要分别计算出某棵树在左边的时候和某棵树在右边的时候的逆序对个数。事实上我们只需要处理权值跨过中点\(mid\)的逆序对,那么所有的逆序对都会在递归过程中被处理仅一次(类似一个分治的过程)。而我们这个时候可以轻易的算出两种情况的逆序对个数,不交换的话是左边那棵树的右半边乘上右边那棵树的的左半边的大小;交换的话则是左边那棵树的左半边乘上左边那棵树的的右半边的大小。
然后每次合并由于都可以交换左右子树,我们就把这次合并中交换和不交换的情况计算一下,取最小值累积就可以了。
空间复杂度:\(O(n \log n)\),时间复杂度 \(O(n \log n)\)
// luogu-judger-enable-o2
//Isaunoya
#include<bits/stdc++.h>
using namespace std ;
inline int read() { register int x = 0 ; register int f = 1 ; register char c = getchar() ;
for( ; ! isdigit(c) ; c = getchar()) if(c == '-') f = -1 ;
for( ; isdigit(c) ; c = getchar()) x = (x << 1) + (x << 3) + (c & 15) ;
return x * f ;
} int st[105] ;
template < typename T > inline void write(T x , char c = '\n') { int tp = 0 ;
if(x == 0) return (void) puts("0") ;
if(x < 0) putchar('-') , x = -x ;
for( ; x ; x /= 10) st[++ tp] = x % 10 ;
for( ; tp ; tp --) putchar(st[tp] + '0') ;
putchar(c) ;
}
//#define Online_Judge
const int N = 2e5 + 10 ;
int n ;
int ls[N << 5] , rs[N << 5] ;
int cnt = 0 ;
long long val[N << 5] ;
inline void Upd(int l , int r , int & pos , int v) {
if(! pos) pos = ++ cnt ;
val[pos] ++ ;
if(l == r) return ;
int mid = l + r >> 1 ;
if(v <= mid) Upd(l , mid , ls[pos] , v) ;
else Upd(mid + 1 , r , rs[pos] , v) ;
}
long long ans1 = 0 , ans2 = 0 ;
//inline int Merge(int x , int y) {
// if(! x || ! y) return x | y ;
// val[x] += val[y] ;
// ans1 += val[rs[x]] * val[ls[y]] ;
// ans2 += val[ls[x]] * val[rs[y]] ;
// ls[x] = Merge(ls[x] , ls[y]) ;
// rs[x] = Merge(rs[x] , rs[y]) ;
//}
inline void Merge(int & x , int y) {
if(! x || ! y ) {
x = x + y ;
return ;
}
val[x] += val[y] ;
ans1 += val[rs[x]] * val[ls[y]] ;
ans2 += val[ls[x]] * val[rs[y]] ;
Merge(ls[x] , ls[y]) ;
Merge(rs[x] , rs[y]) ;
}
long long ans = 0 ;
inline void Dfs(int &x) {
int tmp = read() ; x = 0 ;
int Ls , Rs ;
if(! tmp) {
Dfs(Ls) , Dfs(Rs) ;
ans1 = ans2 = 0 ;
x = Ls ;
Merge(x , Rs) ;
ans += min(ans1 , ans2) ;
}
else Upd(1 , n , x , tmp) ;
}
signed main() {
#ifdef Online_Judge
freopen("testdata.in" , "r" , stdin) ;
freopen("testdata2.out" , "w" , stdout) ;
#endif
n = read() ;
int tmp = 0 ;
Dfs(tmp) ;
write(ans) ;
return 0 ;
}
其实还有一题…
树剖上面线段树合并…
//Isaunoya
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize("Ofast")
#include<bits/stdc++.h>
using namespace std ;
inline int read() {
register int x = 0 ;
register int f = 1 ;
register char c = getchar() ;
for( ; ! isdigit(c) ; c = getchar()) if(c == '-') f = -1 ;
for( ; isdigit(c) ; c = getchar()) x = (x << 1) + (x << 3) + (c & 15) ;
return x * f ;
}
int st[105] ;
template < typename T > inline void write(T x , char c = '\n') {
int tp = 0 ;
if(x == 0) return (void) puts("0") ;
if(x < 0) putchar('-') , x = -x ;
for( ; x ; x /= 10) st[++ tp] = x % 10 ;
for( ; tp ; tp --) putchar(st[tp] + '0') ;
putchar(c) ;
}
//#define Online_Judge
int n , m ;
const int N = 1e5 + 10 ;
int lc[N << 6] ;
int rc[N << 6] ;
int rt[N << 6] ;
int mx[N << 6] , id[N << 6] ;
int cnt = 0 ;
inline void Push_up(int rt) {
if(mx[lc[rt]] >= mx[rc[rt]]) {
mx[rt] = mx[lc[rt]] ;
id[rt] = id[lc[rt]] ;
} else {
mx[rt] = mx[rc[rt]] ;
id[rt] = id[rc[rt]] ;
}
}
inline int Merge(int a , int b , int l , int r) {
if(! a || ! b) return a | b ;
if(l == r) {
mx[a] += mx[b] , id[a] = l ;
return a ;
}
int mid = l + r >> 1 ;
lc[a] = Merge(lc[a] , lc[b] , l , mid) ;
rc[a] = Merge(rc[a] , rc[b] , mid + 1 , r) ;
return Push_up(a) , a ;
}
inline void Insert(int & p , int l , int r , int pos , int v) {
if(! p) p = ++ cnt ;
if(l == r) {
id[p] = l ;
mx[p] += v ;
return ;
}
int mid = l + r >> 1 ;
if(pos <= mid) Insert(lc[p] , l , mid , pos , v) ;
else Insert(rc[p] , mid + 1 , r , pos , v) ;
Push_up(p) ;
return ;
}
struct node {
int v ;
int nxt ;
};
int tot = 0 ;
int head[N] ;
node e[N << 1] ;
inline void Add(int u , int v) {
e[++ tot].v = v ;
e[tot].nxt = head[u] ;
head[u] = tot ;
return ;
}
int top[N] , fa[N] ;
int d[N] , idx[N] ;
int size[N] , son[N] ;
int Idx = 0 ;
inline void Dfs1(int u) {
size[u] = 1 ;
for(register int i = head[u] ; i ; i = e[i].nxt) {
int v = e[i].v ;
if(v == fa[u]) continue ;
fa[v] = u ;
d[v] = d[u] + 1 ;
Dfs1(v) ;
size[u] += size[v] ;
if(size[v] > size[son[u]]) son[u] = v ;
}
}
inline void Dfs2(int u , int t) {
idx[u] = ++ Idx ;
top[u] = t ;
if(! son[u]) return ;
Dfs2(son[u] , t) ;
for(register int i = head[u] ; i ; i = e[i].nxt) {
int v = e[i].v ;
if(v ^ fa[u] && v ^ son[u]) Dfs2(v , v) ;
}
}
inline int GetLca(int x , int y) {
int fx = top[x] ;
int fy = top[y] ;
while(fx ^ fy) {
if(d[fx] < d[fy]) swap(fx , fy) , swap(x , y) ;
x = fa[fx] ;
fx = top[x] ;
}
if(d[x] > d[y]) swap(x , y) ;
return x ;
}
int ans[N] ;
inline void GetAns(int u , int fa) {
for(register int i = head[u ] ; i ; i = e[i].nxt) {
int v = e[i].v ;
if(v == fa) continue ;
GetAns(v , u) ;
rt[u] = Merge(rt[u] , rt[v] , 1 , N) ;
}
ans[u] = id[rt[u]] ;
if(mx[rt[u]] == 0) ans[u] = 0 ;
}
signed main() {
#ifdef Online_Judge
freopen("testdata.in" , "r" , stdin) ;
freopen("testdata2.out" , "w" , stdout) ;
#endif
n = read() ;
m = read() ;
for(register int i = 1 ; i <= n - 1 ; i ++) {
int u , v ;
u = read() , v = read() ;
Add(u , v) ;
Add(v , u) ;
}
Dfs1(1) ;
Dfs2(1 , 1) ;
for(register int i = 1 ; i <= m ; i ++) {
int x = read() , y = read() , z = read() ;
int lca = GetLca(x , y) ;
Insert(rt[lca] , 1 , N , z , - 1) ;
Insert(rt[fa[lca]] , 1 , N , z , - 1) ;
Insert(rt[x] , 1 , N , z , 1) ;
Insert(rt[y] , 1 , N , z , 1) ;
}
GetAns(1 , 0) ;
for(register int i = 1 ; i <= n ; i ++) write(ans[i]) ;
return 0 ;
}