权值线段树学习笔记
权值线段树学习笔记
参考博文:
https://www.cnblogs.com/zmyzmy/p/9529234.html
权值线段树:
- 权值线段树维护数的个数,数组下标代表整个值域,如果太大可以采用离散化。
定义:
struct SegmentTree
{
int l, r;
int s; //节点p的s表示这一段值域中数的个数总和
#define l(x) tree[x].l
#define r(x) tree[x].r
#define s(x) tree[x].s
#define lson (p<<1)
#define rson (p<<1|1)
}tree[maxn<<2];
建树:
void build(int p, int l, int r)
{
l(p) = l, r(p) = r;
if(l == r)
{
//初始化
return;
}
int mid = (l + r) >> 1;
build(lson, l, mid);
build(rson, mid+1, r);
//pushup()
}
单点更新:
void change(int p, int x)
{
if(l(p) == r(p))
{
//更新数据
return;
}
int mid = (l(p) + r(p)) >> 1;
if(x <= mid) change(lson, x);
else change(rson, x);
//pushup()
}
询问整体第\(k\)小:
//询问整个区间第k小
//s(p)代表l(p)到r(p)值域中树的个数总和
int query(int p, int k)
{
if(l(p) == r(p))
return l(p); //由于数组下标维护的是值域,直接返回下标
if(k <= s(lson)) return query(lson, k); //在左子树中
else return query(rson, k - s(lson)); //在右子树中,感觉和平衡树好像
}
例题1:洛谷 https://www.luogu.org/problem/P1801
思路:
- 依题意模拟
代码:
#include<bits/stdc++.h>
#include<cstring>
using namespace std;
typedef long long ll;
const int maxn = 2e5 + 10;
int a[maxn], num[maxn], u[maxn];
int n, m, len;
struct SegmentTree
{
int l, r;
int s; //节点p的s表示这一段值域中数的个数总和
#define l(x) tree[x].l
#define r(x) tree[x].r
#define s(x) tree[x].s
#define lson (p<<1)
#define rson (p<<1|1)
}tree[maxn<<2];
void build(int p, int l, int r)
{
l(p) = l, r(p) = r;
if(l == r) return;
int mid = (l + r) >> 1;
build(lson, l, mid);
build(rson, mid+1, r);
}
void change(int p, int x)
{
if(l(p) == r(p))
{
s(p) += 1;
return;
}
int mid = (l(p) + r(p)) >> 1;
if(x <= mid) change(lson, x);
else change(rson, x);
s(p) = s(lson) + s(rson);
}
//询问整个区间第k大
//s(p)代表l(p)到r(p)值域中树的个数总和
int query(int p, int k)
{
if(l(p) == r(p))
return l(p); //由于数组下标维护的是值域,直接返回下标
if(k <= s(lson)) return query(lson, k); //在左子树中
else return query(rson, k - s(lson)); //在右子树中,感觉和平衡树好像
}
int main()
{
scanf("%d%d", &m, &n);
for(int i = 1; i <= m; i++)
{
scanf("%d", &a[i]);
num[i] = a[i];
}
for(int i = 1; i <= n; i++)
scanf("%d", &u[i]);
sort(num + 1, num + 1 + m);
len = unique(num + 1, num + 1 + m) - num - 1;
build(1, 1, len);
int cnt = 0, k = 0;
while(n != cnt)
{
cnt++;
for(int i = u[cnt-1] + 1; i <= u[cnt]; i++)
{
int y = lower_bound(num+1, num+1+len, a[i]) - num;
//y是a(i)在num里的下标
change(1, y);
}
cout << num[query(1, ++k)] << endl;
}
return 0;
}
例题2:洛谷1908:逆序对(权值线段树写法)
题意描述:
- 求逆序对数目。
思路:
- 见注释
代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 5e5 + 10;
ll ans;
int n, a[maxn], num[maxn], len;
struct SegmentTree
{
int l, r;
ll s;
#define l(x) tree[x].l
#define r(x) tree[x].r
#define lson (p<<1)
#define rson (p<<1|1)
#define s(x) tree[x].s
}tree[maxn<<2];
inline void pushup(int p){
s(p) = s(lson) + s(rson);
}
inline void build(int p, int l, int r)
{
l(p) = l, r(p) = r;
if(l == r) return;
int mid = (l + r) >> 1;
build(lson, l, mid);
build(rson, mid + 1, r);
}
inline void change(int p, int x)
{
if(l(p) == r(p))
{
s(p)++;
return;
}
int mid = (l(p) + r(p)) >> 1;
if(x <= mid) change(lson, x);
else change(rson, x);
pushup(p);
}
ll query(int p, int x)
{
if(l(p) == r(p)) return s(p);
int mid = (l(p) + r(p)) >> 1;
if(x <= mid) return query(lson, x) + s(rson);
else return query(rson, x);
}
int main()
{
scanf("%d", &n);
for(int i = 1; i <= n; i++)
{
scanf("%d", &a[i]);
num[i] = a[i];
}
build(1, 1, n);
sort(num + 1, num + 1 + n);
len = unique(num + 1, num + 1 + n) - num - 1;
for(int i = 1; i <= n; i++)
{
int p = lower_bound(num + 1, num + 1 + n, a[i]) - num;
a[i] = p;
}
for(int i = 1; i <= n; i++) //枚举每个a(i)作为右端点
{ //看树中有多少比他大的数字
ans += query(1, a[i] + 1); //寻找比当前数大的数字的个数
//+1是因为要过滤掉等于a(i)的
change(1, a[i]); //在权值线段树中加上该节点
}
cout << ans << endl;
return 0;
}
例题3:hdu_4217
题意描述:
- 给定一个\(1\)到\(n\)的序列。每次操作查询序列第\(k\)小的数字加入答案并拿走这个数字,问最后拿走数字的总和是多少
- \(n\leq 3e5\)
#include<bits/stdc++.h>
using namespace std;
const int maxn = 3e5 + 10;
int T, n, m, k, cas;
struct SegmentTree
{
int l, r;
int s;
#define l(x) tree[x].l
#define r(x) tree[x].r
#define lson (p<<1)
#define rson (p<<1|1)
#define s(x) tree[x].s
}tree[maxn<<2];
void pushup(int p){
s(p) = s(lson) + s(rson);
}
void build(int p, int l, int r)
{
l(p) = l, r(p) = r;
if(l == r) {s(p) = 1; return;}
int mid = (l + r) >> 1;
build(lson, l, mid);
build(rson, mid+1, r);
pushup(p);
}
int query(int p, int k)
{
if(l(p) == r(p)) return l(p);
if(k <= s(lson)) return query(lson, k);
else return query(rson, k - s(lson));
}
void change(int p, int x, int val)
{
if(l(p) == r(p))
{
s(p) = val;
return;
}
int mid = (l(p) + r(p)) >> 1;
if(x <= mid) change(lson, x, val);
else change(rson, x, val);
pushup(p);
}
int main()
{
scanf("%d", &T);
while(T--)
{
scanf("%d%d", &n, &m);
build(1, 1, n);
long long ans = 0;
while(m--)
{
scanf("%d", &k);
int num = query(1, k);
ans += num;
change(1, num, 0);
}
printf("Case %d: %lld\n", ++cas, ans);
}
return 0;
}