luogu P3369 【模板】普通平衡树(splay)

嘟嘟嘟


突然觉得splay挺有意思,唯一不足的是这几天是一天一道,debug到崩溃。


做了几道平衡树基础题后,对这题有莫名的自信,还算愉快的敲完了代码后,发现样例都过不去,然后就陷入了无限的debug环节了……算了,伤心的事就别再提了。


说一下这题怎么做:
1.插入
不说了

void insert(int x)
{
  int now = root, f = 0;
  while(now && t[now].val != x) f = now, now = t[now].ch[x > t[now].val];
  if(now) t[now].cnt++;
  else
    {
      now = ++ncnt;
      if(f) t[f].ch[x > t[f].val] = now;
      t[now].fa = f;
      t[now].ch[0] = t[now].ch[1] = 0;	
      t[now].siz = t[now].cnt = 1; t[now].val = x;
    }
  splay(now, 0);
}

2.删除
\(x\)的前驱\(a\),后继\(b\),把\(a\)旋到根,再把\(b\)旋到\(a\)的右儿子,这样\(b\)的左子树就只剩\(x\)唯一一个节点了。若有多个,\(cnt--\),否则清零。
然后一定要把\(b\)旋转到根。刚开始我很不理解,多亏了学姐说:不旋转怎么更新平衡树啊!我才知道旋转还有一个作用,就是更新这个节点的所有祖先的值。跟线段树回溯的时候更新祖先节点一样。
我刚开始写了一个\(clear\)函数,然后因为没传实参debug了半天……

void del(int x)
{
  int a = pre(x), b = nxt(x);		
  splay(a, 0); splay(b, a);
  int now = t[b].ch[0];
  if(t[now].cnt > 1) t[now].cnt--, splay(now, 0);
  else t[b].ch[0] = 0;
}

3.查询\(x\)数的排名(数据保证\(x\)存在)
有一个很棒的做法是查找\(x\)并把\(x\)旋到根,然后返回根的左子树的大小就行。
然而我刚开始是用\(bst\)的思路写的:如果\(x\)小于当前节点权值,到左子树找;如果等于,返回累加值\(ret\)加当前节点大小;否则,\(ret\)加上左子树和当前节点大小,然后去右子树找。
这个思路是没问题的,写起来也不难,然而我被某谷的第\(12\)个点卡掉了:数据是这样的:先添加\(50000\)个点,接下来\(50000\)个操作全是查找\(i\)\(50000\)的排名。按我的写法是虽然是稳定的\(O(n \log{n})\),但就是TLE;而换了我刚开始说的那个写法后,由于询问可能是\(O(1)\)可能是\(O(n \log{n})\)的,竟然迅速的AC了……
注释掉的是\(bst\)的写法

int queryKth(int x)
{
  find(x);
  return t[t[root].ch[0]].siz;
  /*int now = root, ret = 0;
  while(1)
    {
      if(t[now].val > x) now = t[now].ch[0];
      else if(t[now].val == x) return ret + t[t[now].ch[0]].siz;
      else ret += t[t[now].ch[0]].siz + t[now].cnt, now = t[now].ch[1];
      }*/
}

4.查询排名为\(x\)的数
\(bst\)的写法就行:如果\(x\)小于等于左子树大小,去左子树找;否则如果小于等于左子树加上当前节点的大小,返回当前节点权值;否则\(x\)减去左子树和节点大小,到右子树去找。

int queryX(int k)
{
  int now = root;
  while(1)
    {
      if(k <= t[t[now].ch[0]].siz) now = t[now].ch[0];
      else if(k <= t[t[now].ch[0]].siz + t[now].cnt) return t[now].val;
      else k -= (t[t[now].ch[0]].siz + t[now].cnt), now = t[now].ch[1];
    }
}

5,6.前驱,后继
不说了

void find(int x)
{
  int now = root;
  if(!now) return;
  while(t[now].val != x && t[now].ch[x > t[now].val]) now = t[now].ch[x > t[now].val];
  splay(now, 0);
}
int pre(int x)
{
  find(x);
  //_PrintTr(root);
  if(t[root].val < x) return root;
  int now = t[root].ch[0];
  while(t[now].ch[1]) now = t[now].ch[1];
  return now;
}
int nxt(int x)
{
  find(x);
  //_PrintTr(root);
  if(t[root].val > x) return root;
  int now = t[root].ch[1];
  while(t[now].ch[0]) now = t[now].ch[0];
  return now;
}

最后当然要放完整代码啦。(自认为挺短的)
#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
using namespace std;
#define enter puts("") 
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define rg register
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 1e5 + 5;
inline ll read()
{
  ll ans = 0;
  char ch = getchar(), last = ' ';
  while(!isdigit(ch)) {last = ch; ch = getchar();}
  while(isdigit(ch)) {ans = (ans << 1) + (ans << 3) + ch - '0'; ch = getchar();}
  if(last == '-') ans = -ans;
  return ans;
}
inline void write(ll x)
{
  if(x < 0) x = -x, putchar('-');
  if(x >= 10) write(x / 10);
  putchar(x % 10 + '0');
}

int n;
struct Tree
{
  int ch[2], fa;
  int siz, cnt, val;
}t[maxn];
int root, ncnt = 0;

void _PrintTr(int now)
{
  if(!now) return;
  printf("nd:%d val:%d ls:%d rs:%d\n", now, t[now].val, t[t[now].ch[0]].val, t[t[now].ch[1]].val);
  _PrintTr(t[now].ch[0]); _PrintTr(t[now].ch[1]);
}

void pushup(int now)
{
  t[now].siz = t[t[now].ch[0]].siz + t[t[now].ch[1]].siz + t[now].cnt;
}
void rotate(int x)
{
  int y = t[x].fa, z = t[y].fa, k = (t[y].ch[1] == x);
  t[z].ch[t[z].ch[1] == y] = x; t[x].fa = z;
  t[y].ch[k] = t[x].ch[k ^ 1]; t[t[x].ch[k ^ 1]].fa = y;
  t[x].ch[k ^ 1] = y; t[y].fa = x;
  pushup(y); pushup(x);
}
void splay(int x, int s)
{
  while(t[x].fa != s)
    {
      int y = t[x].fa, z = t[y].fa;
      if(z != s)
	{
	  if((t[z].ch[1] == y) ^ (t[y].ch[1] == x)) rotate(x);
	  else rotate(y);
	}
      rotate(x);
    }
  if(!s) root = x;
}
void insert(int x)
{
  int now = root, f = 0;
  while(now && t[now].val != x) f = now, now = t[now].ch[x > t[now].val];
  if(now) t[now].cnt++;
  else
    {
      now = ++ncnt;
      if(f) t[f].ch[x > t[f].val] = now;
      t[now].fa = f;
      t[now].ch[0] = t[now].ch[1] = 0;	
      t[now].siz = t[now].cnt = 1; t[now].val = x;
    }
  splay(now, 0);
}
void find(int x)
{
  int now = root;
  if(!now) return;
  while(t[now].val != x && t[now].ch[x > t[now].val]) now = t[now].ch[x > t[now].val];
  splay(now, 0);
}
int pre(int x)
{
  find(x);
  //_PrintTr(root);
  if(t[root].val < x) return root;
  int now = t[root].ch[0];
  while(t[now].ch[1]) now = t[now].ch[1];
  return now;
}
int nxt(int x)
{
  find(x);
  //_PrintTr(root);
  if(t[root].val > x) return root;
  int now = t[root].ch[1];
  while(t[now].ch[0]) now = t[now].ch[0];
  return now;
}
void del(int x)
{
  int a = pre(x), b = nxt(x);		
  splay(a, 0); splay(b, a);
  int now = t[b].ch[0];
  if(t[now].cnt > 1) t[now].cnt--, splay(now, 0);
  else t[b].ch[0] = 0;
}
int queryKth(int x)
{
  find(x);
  return t[t[root].ch[0]].siz;
  /*int now = root, ret = 0;
  while(1)
    {
      if(t[now].val > x) now = t[now].ch[0];
      else if(t[now].val == x) return ret + t[t[now].ch[0]].siz;
      else ret += t[t[now].ch[0]].siz + t[now].cnt, now = t[now].ch[1];
      }*/
}
int queryX(int k)
{
  int now = root;
  while(1)
    {
      if(k <= t[t[now].ch[0]].siz) now = t[now].ch[0];
      else if(k <= t[t[now].ch[0]].siz + t[now].cnt) return t[now].val;
      else k -= (t[t[now].ch[0]].siz + t[now].cnt), now = t[now].ch[1];
    }
}

int main()
{
  //freopen("test.in", "r", stdin);
  //freopen("ha.out", "w", stdout);
  insert(-INF); insert(INF);
  n = read();
  for(int i = 1; i <= n; ++i)	
    {
      int op = read(), x = read();
      if(op == 1) insert(x);
      else if(op == 2) del(x);
      else if(op == 3) write(queryKth(x)), enter;
      else if(op == 4) write(queryX(x + 1)), enter;
      else if(op == 5) write(t[pre(x)].val), enter;
      else write(t[nxt(x)].val), enter;
      //_PrintTr(root);
    }
  return 0;
}
posted @ 2018-12-03 10:25  mrclr  阅读(144)  评论(0编辑  收藏  举报