Splay+离散化 - HDU 3436 - Queue-jumpers
Splay+离散化 - HDU 3436 - Queue-jumpers
因为太菜怕离散化写错,一开始尝试着写在线算法竟然还过了(玄学复杂度不会证明,貌似破坏了splay的期望logN)。本文先介绍离线的正解,文末附在线算法的代码。
1. 离线算法
离散化其实挺好想,数据范围N=1e8肯定不可能开一个1e8的splay。离散化之后就是常规的splay处理。这题多了一个少见的top操作,实现方法就是splay(L, 0) splay(R ,L)然后在 L 和 R之间插入新节点。
这里贴一个他人的代码,代码来源:Przz
/*
hdu 3436 splay树+离散化*
本来以为很好做的,写到中途发现10^8,GG
然后参考了下,把操作不用的区间缩点离散化处理
然后就是删除点,感觉自己开始写的太麻烦了,将要删除的点移动到根,如果没有儿子直接删掉,
否则将右树的最小点移到ch[r][1]使右树没有左子树,然后把根的左树接到右树上
hhh-2016-02-20 22:22:22
*/
#include <functional>
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <map>
#include <cmath>
using namespace std;
typedef long long ll;
typedef long double ld;
#define key_value ch[ch[root][1]][0]
const int maxn = 200010;
int ch[maxn][2];
int pre[maxn],key[maxn],siz[maxn],num[maxn];
int root,tot,cnt,n,TOT;
int posi[maxn];
char qry[maxn][10];
int op[maxn];
int te[maxn];
int s[maxn],e[maxn];
void Treaval(int x) {
if(x) {
Treaval(ch[x][0]);
printf("结点%2d:左儿子 %2d 右儿子 %2d 父结点 %2d size = %2d ,key = %2d num= %2d \n",x,ch[x][0],ch[x][1],pre[x],siz[x],key[x],num[x]);
Treaval(ch[x][1]);
}
}
void debug() {printf("%d\n",root);Treaval(root);}
void push_up(int r)
{
int lson = ch[r][0],rson = ch[r][1];
siz[r] = siz[lson] + siz[rson] + num[r];
}
void push_down(int r)
{
}
void inOrder(int r)
{
if(!r)return;
inOrder(ch[r][0]);
printf("%d ",key[r]);
inOrder(ch[r][1]);
}
void NewNode(int &r,int far,int k)
{
r = ++tot;
posi[k] = r;
key[r] = k;
pre[r] = far;
ch[r][0] = ch[r][1] = 0;
siz[r] = num[r] = e[k]-s[k]+1;
}
void rotat(int x,int kind)
{
int y = pre[x];
push_down(y);
push_down(x);
ch[y][!kind] = ch[x][kind];
pre[ch[x][kind]] = y;
if(pre[y])
ch[pre[y]][ch[pre[y]][1]==y] = x;
pre[x] = pre[y];
ch[x][kind] = y;
pre[y] = x;
push_up(y);
}
void build(int &x,int l,int r,int far)
{
if(l > r) return ;
int mid = (l+r) >>1;
NewNode(x,far,mid);
build(ch[x][0],l,mid-1,x);
build(ch[x][1],mid+1,r,x);
push_up(x);
}
void splay(int r,int goal)
{
push_down(r);
while(pre[r] != goal)
{
if(pre[pre[r]] == goal)
{
push_down(pre[r]);
push_down(r);
rotat(r,ch[pre[r]][0] == r);
}
else
{
push_down(pre[pre[r]]);
push_down(pre[r]);
push_down(r);
int y = pre[r];
int kind = ch[pre[y]][0] == y;
if(ch[y][kind] == r)
{
rotat(r,!kind);
rotat(r,kind);
}
else
{
rotat(y,kind);
rotat(r,kind);
}
}
}
push_up(r);
if(goal == 0)
root = r;
}
int Bin(int x)
{
int l = 0,r = TOT-1;
while(l<=r)
{
int mid=(l+r)>>1;
if(s[mid]<=x&&e[mid]>=x)
return mid;
if(e[mid]<x)
l=mid+1;
else
r=mid-1;
}
}
int get_min(int r)
{
push_down(r);
while(ch[r][0])
{
r = ch[r][0];
push_down(r);
}
return r;
}
int get_kth(int r,int k)
{
int t = siz[ch[r][0]];
if(k<=t)
return get_kth(ch[r][0],k);
else if(k<=t+num[r])
return s[key[r]]+(k-t)-1;
else
return get_kth(ch[r][1],k-t-num[r]);
}
void delet()
{
if(ch[root][0] == 0 || ch[root][1] == 0)
{
root = ch[root][0] + ch[root][1];
pre[root] = 0;
return;
}
int k = get_min(ch[root][1]);
splay(k,root);
ch[ch[root][1]][0] = ch[root][0];
root = ch[root][1];
pre[ch[root][0]] = root;
pre[root] = 0;
push_up(root);
}
int top(int t)
{
int r = Bin(t);
r = posi[r];
splay(r,0);
delet();
splay(get_min(root),0);
ch[r][0] = 0;
ch[r][1] = root;
pre[root] = r;
root = r;
pre[root] = 0;
push_up(root);
// debug();
}
int Query(int x)
{
int r = Bin(x);
r = posi[r];
splay(r,0);
return siz[ch[r][0]]+1;
}
int get_rank(int x,int k)
{
int t = siz[ch[x][0]];
if(k <= t)
return get_rank(ch[x][0],k);
else
return get_rank(ch[x][1],k-t);
}
void ini(int n)
{
tot = root = 0;
ch[root][0] = ch[root][1] = pre[root] = siz[root] = num[root] = 0 ;
build(root,0,n-1,0);
push_up(ch[root][1]);
push_up(root);
//inOrder(root);
}
int main()
{
int q,T;
int cas =1;
scanf("%d",&T);
while(T--)
{
scanf("%d%d",&n,&q) ;
if(n == -1 && q == -1)
break;
int tcn = 0;
printf("Case %d:\n",cas++);
for(int i =1; i <= q; i++)
{
scanf("%s%d",qry[i],&op[i]);
if(qry[i][0] == 'T' || qry[i][0] == 'Q')
te[tcn++] = op[i];
}
te[tcn++] = n;
te[tcn++] = 1;
sort(te,te+tcn);
TOT= 0;
s[TOT] = te[0],e[TOT] = te[0],TOT++;
for(int i = 1; i < tcn; i++)
{
if(te[i] != te[i-1] && i)
{
if(te[i] - te[i-1] > 1)
{
s[TOT] = te[i-1]+1;
e[TOT] = te[i]-1;
TOT++;
}
s[TOT] = te[i];
e[TOT] = te[i];
TOT++;
}
}
ini(TOT);
//debug();
for(int i = 1; i <= q; i++)
{
if(qry[i][0]=='T')
top(op[i]);
else if(qry[i][0]=='Q')
printf("%d\n", Query(op[i]));
else
printf("%d\n",get_kth(root,op[i]));
}
//debug();
}
return 0;
}
2. 在线算法(不能证明复杂度)
2.1 节点定义
还是我们提到的问题,本题数据范围N=1e8,不可能为每个点单独开一个节点存。考虑用一个节点存一个线短。只有需要访问某一个线段的子线段时,才将这个线段拆分掉。
struct Node{
int l, r; // 所代表的区间范围
int p, size, s[2];
void clear(){
l = r = p = size = s[0] = s[1] = 0;
}
void init(int _l, int _r, int _p){
s[0] = s[1] = 0;
l = _l;
r = _r;
p = _p;
size = _r-_l+1;
}
}tr[M];
我们将线段的左右关系当作Splay插入时的大小关系
int insert(int l, int r){
int u = rt, p = 0;
while(u){
p = u;
if(l > tr[u].r){
u = tr[u].s[1];
}else{
u = tr[u].s[0];
}
}
u = ++idx;
tr[u].init(l, r, p);
if(p){
if(l > tr[p].r){
tr[p].s[1] = u;
}else{
tr[p].s[0] = u;
}
}
splay(u, 0);
return u;
}
初始化时,插入左右哨兵和整个区间
int L = insert(-INF, -INF);
int R = insert(INF, INF);
int u = insert(1, n);
2.2 拆分区间
假设我现在需要进行top(x)
操作,这个操作实际上需要划分出[x, x]
这段区间。这个区间可能已经在之前的操作中被划分出来了,也可能没有被划分出来。
我们可以用一个map
来维护目前已经划分出的区间
map<int, int> dict;
// key: 目前已经划分的各段区间的右端点
// value: 该区间的splay下标
接下来就可以进行区间的拆分了
void split_node(int pos, int l, int r){
// 将某段区间[L,R] 分裂出 [l,r] 和 剩余部分
splay(pos, 0); // 这样保证了该节点的size区间端点更新不会影响其祖先
if(l == tr[pos].l){
int tmp = tr[pos].r;
tr[pos].r = r;
push_up(pos);
// 插入 [r+1, R]
int u = add_suc(r+1, tmp, pos);
// [L,R] -> [L,r] [r+1,R]
dict.find(tmp)->second = u; // R 存到u中
dict.insert(make_pair(r, pos)); // r 存到pos中
}else if(r == tr[pos].r){
int tmp = tr[pos].l;
tr[pos].l = l;
push_up(pos);
int u = add_pre(tmp, l-1, pos);
// [L,R] -> [L,l-1] [l,R]
dict.insert(make_pair(l-1, u));
}else{
// [L,R] -> [L, l-1], [l,r] , [r+1, R]
int tmpl = tr[pos].l, tmpr = tr[pos].r;
tr[pos].l = l;
tr[pos].r = r;
push_up(pos);
int u = add_pre(tmpl, l-1, pos);
int v = add_suc(r+1, tmpr, pos);
dict.find(tmpr)->second = v;
dict.insert(make_pair(l-1, u));
dict.insert(make_pair(r, pos));
}
}
其中,add_pre
为将一个新节点插入到另一个节点的前驱,add_suc
为将一个新节点插入到另一个节点的后继。这个操作不是常规的splay插入操作,因此无法保证splay仍能保持期望log(N)的复杂度(所以我说这个做法比较玄学)。
void push_up(int x){
tr[x].size = tr[tr[x].s[0]].size + tr[tr[x].s[1]].size + tr[x].r - tr[x].l + 1;
}
void push_up_to_root(int x){
while(x){
push_up(x);
x = tr[x].p;
}
}
int add_pre(int l, int r, int p){
int u = ++idx;
tr[u].init(l, r, p);
tr[u].s[0] = tr[p].s[0];
tr[tr[p].s[0]].p = u;
tr[p].s[0] = u;
push_up_to_root(u);
splay(u, 0);
return u;
}
int add_suc(int l, int r, int p){
int u = ++idx;
tr[u].init(l, r, p);
tr[u].s[1] = tr[p].s[1];
tr[tr[p].s[1]].p = u;
tr[p].s[1] = u;
push_up_to_root(u);
splay(u, 0);
return u;
}
2.3 top(x)
top操作需要将区间[x, x]
单独划分出来,把原本的[x, x]
删掉,然后将新的[x, x]
插入到[-INF, -INF]
的后继
那么如何判断x是否被划分出来了呢?
void top(int x){
map<int, int> :: iterator it_x = dict.find(x);
if(it_x == dict.end()){
map<int, int> :: iterator it_lb = dict.lower_bound(x);
split_node(it_lb->second, x, x);
it_x = dict.find(x);
}else if(tr[it_x->second].l != tr[it_x->second].r){
split_node(it_x->second, x, x);
it_x = dict.find(x);
}
int u = it_x->second;
int xl = get_pre(u);
int xr = get_suc(u);
splay(xl, 0);
splay(xr, xl);
tr[xr].s[0] = 0;
push_up(xr);
push_up(xl);
tr[u].l = -123456;
// 更新 x 的位置
int v = insert(-INF+1,-INF+1);
splay(L, 0);
splay(v, L);
splay(R, v);
tr[v].l = tr[v].r = x;
it_x->second = v;
}
2.4 query(x)
这个操作很容易实现,只需要先用map查询节点的splay下标,然后将其转至根,统计左儿子大小即可。注意哨兵[-INF, -INF]
会占据一位。
int query(int x){
map<int, int> :: iterator it_x = dict.find(x);
if(it_x == dict.end()){
map<int, int> :: iterator it_lb = dict.lower_bound(x);
split_node(it_lb->second, x, x);
it_x = dict.find(x);
}else if(tr[it_x->second].l != tr[it_x->second].r){
split_node(it_x->second, x, x);
it_x = dict.find(x);
}
int u = it_x->second;
splay(u, 0);
return tr[tr[u].s[0]].size;
}
2.5 rank_x(x)
这个操作也很简单,与常规的getk(k)稍微有一点区别。
int rank_x(int x){
++x;
int u = rt;
while(u){
if(tr[tr[u].s[0]].size >= x){
u = tr[u].s[0];
}else if(tr[tr[u].s[0]].size + tr[u].r - tr[u].l + 1 >= x){
x -= tr[tr[u].s[0]].size;
return tr[u].l + x - 1;
}else{
x -= tr[tr[u].s[0]].size + tr[u].r - tr[u].l + 1;
u = tr[u].s[1];
}
}
return -1;
}
2.6 完整代码
#include <bits/stdc++.h>
using namespace std;
const int M = 10+1;
const int INF = 0x3fffffff;
int T;
int L, R;
int n, m, num;
char op[10];
map<int,int> dict;
// [L,R]段所存的下标(这个L不需要存储)
// 当前所有存在的区间中,右端点>=L的最小值
int rt, idx;
struct Node{
int l, r; // 所代表的区间范围
int p, size, s[2];
void clear(){
l = r = p = size = s[0] = s[1] = 0;
}
void init(int _l, int _r, int _p){
s[0] = s[1] = 0;
l = _l;
r = _r;
p = _p;
size = _r-_l+1;
}
}tr[M];
int ws(int x){
return tr[tr[x].p].s[1] == x;
}
void push_up(int x){
tr[x].size = tr[tr[x].s[0]].size + tr[tr[x].s[1]].size + tr[x].r - tr[x].l + 1;
}
void push_up_to_root(int x){
while(x){
push_up(x);
x = tr[x].p;
}
}
void rotate(int x){
int y = tr[x].p;
int z = tr[y].p;
int k = ws(x);
tr[z].s[ws(y)] = x;
tr[y].p = x;
tr[y].s[k] = tr[x].s[k^1];
tr[tr[x].s[k^1]].p = y;
tr[x].p = z;
tr[x].s[k^1] = y;
push_up(y);
push_up(x);
}
void splay(int x, int k){
while(tr[x].p != k){
int y = tr[x].p;
int z = tr[y].p;
if(z != k){
if(ws(x) ^ ws(y)){
rotate(x);
}else{
rotate(y);
}
}
rotate(x);
}
if(!k) rt = x;
}
int insert(int l, int r){ // 我们默认,插入节点的时候已经将原区间划分开了
int u = rt, p = 0;
while(u){
p = u;
if(l > tr[u].r){
u = tr[u].s[1];
}else{
u = tr[u].s[0];
}
}
u = ++idx;
tr[u].init(l, r, p);
if(p){
if(l > tr[p].r){
tr[p].s[1] = u;
}else{
tr[p].s[0] = u;
}
}
splay(u, 0);
return u;
}
int add_pre(int l, int r, int p){
int u = ++idx;
tr[u].init(l, r, p);
tr[u].s[0] = tr[p].s[0];
tr[tr[p].s[0]].p = u;
tr[p].s[0] = u;
push_up_to_root(u);
splay(u, 0);
return u;
}
int add_suc(int l, int r, int p){
int u = ++idx;
tr[u].init(l, r, p);
tr[u].s[1] = tr[p].s[1];
tr[tr[p].s[1]].p = u;
tr[p].s[1] = u;
push_up_to_root(u);
splay(u, 0);
return u;
}
// ERROR:直接插入会改变某些顺序
void split_node(int pos, int l, int r){ // 将某段区间[L,R] 分裂出 [l,r] 和 剩余部分
splay(pos, 0); // 这样保证了该节点的size区间端点更新不会影响其祖先
if(l == tr[pos].l){
int tmp = tr[pos].r;
tr[pos].r = r;
push_up(pos);
// 插入 [r+1, R]
int u = add_suc(r+1, tmp, pos);
// [L,R] -> [L,r] [r+1,R]
dict.find(tmp)->second = u; // R 存到u中了
dict.insert(make_pair(r, pos)); // r 存到pos中了
}else if(r == tr[pos].r){
int tmp = tr[pos].l;
tr[pos].l = l;
push_up(pos);
int u = add_pre(tmp, l-1, pos);
// [L,R] -> [L,l-1] [l,R]
dict.insert(make_pair(l-1, u));
}else{
// [L,R] -> [L, l-1], [l,r] , [r+1, R]
int tmpl = tr[pos].l, tmpr = tr[pos].r;
tr[pos].l = l;
tr[pos].r = r;
push_up(pos);
int u = add_pre(tmpl, l-1, pos);
int v = add_suc(r+1, tmpr, pos);
dict.find(tmpr)->second = v;
dict.insert(make_pair(l-1, u));
dict.insert(make_pair(r, pos));
}
}
int get_pre(int x){
splay(x, 0);
int u = tr[x].s[0];
while(tr[u].s[1]) u = tr[u].s[1];
return u;
}
int get_suc(int x){
splay(x, 0);
int u = tr[x].s[1];
while(tr[u].s[0]) u = tr[u].s[0];
return u;
}
void top(int x){
map<int, int> :: iterator it_x = dict.find(x);
if(it_x == dict.end()){
map<int, int> :: iterator it_lb = dict.lower_bound(x);
split_node(it_lb->second, x, x);
it_x = dict.find(x);
}else if(tr[it_x->second].l != tr[it_x->second].r){
split_node(it_x->second, x, x);
it_x = dict.find(x);
}
int u = it_x->second;
int xl = get_pre(u);
int xr = get_suc(u);
splay(xl, 0);
splay(xr, xl);
tr[xr].s[0] = 0;
push_up(xr);
push_up(xl);
tr[u].l = -123456;
// 更新 x 的位置
int v = insert(-INF+1,-INF+1);
splay(L, 0);
splay(v, L);
splay(R, v);
tr[v].l = tr[v].r = x;
it_x->second = v;
}
int query(int x){
map<int, int> :: iterator it_x = dict.find(x);
if(it_x == dict.end()){
map<int, int> :: iterator it_lb = dict.lower_bound(x);
split_node(it_lb->second, x, x);
it_x = dict.find(x);
}else if(tr[it_x->second].l != tr[it_x->second].r){
split_node(it_x->second, x, x);
it_x = dict.find(x);
}
int u = it_x->second;
splay(u, 0);
return tr[tr[u].s[0]].size;
}
int rank_x(int x){
++x;
int u = rt;
while(u){
if(tr[tr[u].s[0]].size >= x){
u = tr[u].s[0];
}else if(tr[tr[u].s[0]].size + tr[u].r - tr[u].l + 1 >= x){
x -= tr[tr[u].s[0]].size;
return tr[u].l + x - 1;
}else{
x -= tr[tr[u].s[0]].size + tr[u].r - tr[u].l + 1;
u = tr[u].s[1];
}
}
return -1;
}
void init(){
dict.clear();
for(int i = 1; i <= idx; ++i){
tr[i].clear();
}
rt = 0;
idx = 0;
}
int main(){
scanf("%d", &T);
for(int t = 1; t <= T; ++t){
printf("Case %d:\n", t);
init();
L = insert(-INF,-INF);
R = insert(INF, INF);
scanf("%d%d", &n, &m);
int base = insert(1, n);
dict.insert(make_pair(n, base));
while(m--){
scanf("%s%d", op, &num);
if(*op == 'T'){
top(num);
}else if(*op == 'R'){
printf("%d\n", rank_x(num));
}else{
printf("%d\n", query(num));
}
}
}
return 0;
}
---- suffer now and live the rest of your life as a champion ----