POJ 2892 Tunnel Warfare (SBT + stack)
题意:给定了初始的状态:有n个村庄连成一条直线,现在有三种操作: 1.摧毁一个村庄 2.询问某个村庄,输出与该村庄相连的村庄数量(包括自己) 3.修复被摧毁的村庄,优先修复最近被摧毁的..............
分析:用SBT做的话,摧毁村庄就插入,修复就移除,如果要询问的话:找到第一个大于等于该村庄编号和第一个小于等于该村庄编号的,等价于找到了联通在一起的村庄。
朴素的做法可以 set + stack + 二分 搞之.................
#include <iostream> #include <algorithm> #include <cmath> #include <cstdio> #include <cstdlib> #include <cstring> #include <string> #include <vector> #include <set> #include <queue> #include <stack> #include <climits>//形如INT_MAX一类的 #define MAX 55555 #define INF 0x7FFFFFFF #define REP(i,s,t) for(int i=(s);i<=(t);++i) #define ll long long #define mem(a,b) memset(a,b,sizeof(a)) #define mp(a,b) make_pair(a,b) #define L(x) x<<1 #define R(x) x<<1|1 # define eps 1e-5 //#pragma comment(linker, "/STACK:36777216") ///传说中的外挂 using namespace std; struct sbt { int l,r,s,key; } tr[MAX]; int top , root; void left_rot(int &x) { int y = tr[x].r; tr[x].r = tr[y].l; tr[y].l = x; tr[y].s = tr[x].s; //转上去的节点数量为先前此处节点的size tr[x].s = tr[tr[x].l].s + tr[tr[x].r].s + 1; x = y; } void right_rot(int &x) { int y = tr[x].l; tr[x].l = tr[y].r; tr[y].r = x; tr[y].s = tr[x].s; tr[x].s = tr[tr[x].l].s + tr[tr[x].r].s + 1; x = y; } void maintain(int &x,bool flag) { if(flag == 0) { //左边 if(tr[tr[tr[x].l].l].s > tr[tr[x].r].s) {//左孩子左子树size大于右孩子size right_rot(x); } else if(tr[tr[tr[x].l].r].s > tr[tr[x].r].s) {//左孩子右子树size大于右孩子size left_rot(tr[x].l); right_rot(x); } else return ; } else { //右边 if(tr[tr[tr[x].r].r].s > tr[tr[x].l].s) { //右孩子的右子树大于左孩子 left_rot(x); } else if(tr[tr[tr[x].r].l].s > tr[tr[x].l].s) { //右孩子的左子树大于左孩子 right_rot(tr[x].r); left_rot(x); } else return ; } maintain(tr[x].l,0); maintain(tr[x].r,1); } void insert(int &x,int key) { if(x == 0) { //空节点 x = ++ top; tr[x].l = tr[x].r = 0; tr[x].s = 1; tr[x].key = key; } else { tr[x].s ++; if(key < tr[x].key) insert(tr[x].l,key); else insert(tr[x].r,key); maintain(x,key >= tr[x].key); } } int remove(int &x,int key) { int k; tr[x].s --; if(key == tr[x].key || (key < tr[x].key && tr[x].l == 0) || (key > tr[x].key && tr[x].r == 0)) { k = tr[x].key; if(tr[x].l && tr[x].r) { tr[x].key = remove(tr[x].l,tr[x].key + 1); } else { x = tr[x].l + tr[x].r; } } else if(key > tr[x].key) { k = remove(tr[x].r,key); } else if(key < tr[x].key) { k = remove(tr[x].l,key); } return k; } int pred(int &x,int y,int key) //前驱 小于 { if(x == 0) return tr[y].key ; if(tr[x].key < key) return pred(tr[x].r,x,key); else if(tr[x].key > key) return pred(tr[x].l,y,key); else return key; }//pred(root,0,key) int succ(int &x,int y,int key) { //后继 大于 if(x == 0) return tr[y].key; if(tr[x].key > key) return succ(tr[x].l,x,key); else if(tr[x].key < key) return succ(tr[x].r,y,key); else return key; } int n,m; char c; int st[MAX]; int head = 0; int main() { root = 0; top = 0; int b; scanf("%d%d",&n,&m); for(int i=0; i<m; i++) { cin >> c; if(c == 'D') { scanf("%d",&b); st[head++] = b; insert(root,b); } if(c == 'R') { remove(root,st[--head]); } if(c == 'Q') { scanf("%d",&b); int pre = pred(root,0,b); int suc = succ(root,0,b); if(suc == 0) suc = n+1; if(pre == suc) { puts("0"); continue; } printf("%d\n",suc - pre - 1); } } return 0; }