【Splay】BZOJ 1014-prefix

最近几天写SAM写的烦了...就去找神犇要了一个数据结构题来写。

 

好吧,虽然看起来不是很难,思维难度中等,但是代码复杂度还是有点高的,写了半个下午+晚上一点。

 

总的来说算法有三个:Splay维护hash值(用于字符串比较),二分LCP长度。

不会字符串hash和Splay的请先去点前置技能~

 

将整个字符串看做一个区间,用Splay可以维护这个区间的信息并且对区间进行修改,插入,删除等操作(不过本题不用删除)。

定义rch[i]为i的右孩子,lch[i]为i的左孩子,key[i]为i节点代表的区间的hash值,size[i]为i的子树大小。很容易看出

  key[i] = key[lch[i]] * 27size[rch[i]]+1 + ch[i] * 27size[rch[i]] + key[lch[i]];

 

分开来看看三个过程:

  R:找到区间第k大,直接修改然后一路向上更新维护hash值。

  I :将a号点转到根,a+1号点转到根的右孩子,那么根据平衡树的性质,a+1号点是没有左孩子的。这时候将新的点直接插到左孩子上,向上更新。

  Q:二分出一个长度len,然后从Splay中取出区间[a,a+len-1]的hash值和[b,b+len-1]的hash值,直接比较就行了。合法加大长度,非法减少长度。

 

细节有点多,写出来有点蛋疼的150 line

  1 #include <cstdio>
  2 #include <iostream>
  3 #include <cstring>
  4 #include <iostream>
  5 #define lch t[t[x].s[0]]
  6 #define rch t[t[x].s[1]]
  7 #define fa  t[t[x].f]
  8 #define maxn 155000
  9 using namespace std;
 10 
 11 typedef unsigned int uni;
 12 
 13 struct node { int s[2],key,f,size; uni num,ch; } t[maxn];
 14 int i,j,n,m,k,cnt,LEN,root;
 15 char S[maxn],CC,ch;
 16 uni _27[maxn];
 17 
 18 void update(int x)
 19 {
 20     t[x].size = lch.size + rch.size + 1;
 21     lch.f = x; rch.f = x;
 22     int len = rch.size;
 23     t[x].num = lch.num * _27[len+1] + t[x].ch * _27[len] + rch.num;
 24 }
 25 
 26 void clear(int x) { t[x].s[0] = t[x].s[1] = t[x].f = t[x].num = t[x].ch = t[x].key = t[x].size = 0; }
 27 
 28 void Rou(int x,bool kind) // kind == 0 : L      kind == 1 : R;
 29 {
 30     int v = t[x].s[!kind]; t[x].s[!kind] = t[v].s[kind]; t[v].s[kind] = x;
 31     t[v].f = t[x].f; t[x].f = v; t[t[x].s[!kind]].f = x;
 32 
 33     t[t[v].f].s[x == t[t[v].f].s[1]] = v;
 34     update(x); update(v);
 35 
 36     clear(0);
 37 }
 38 
 39 #define rotate(a) Rou(t[(a)].f,t[t[(a)].f].s[0] == (a))
 40 
 41 void splay(int x,int f) //将x转到父亲为f
 42 {
 43     if (f == 0) root = x;
 44     while (t[x].f != f)
 45     {
 46         int y = t[x].f, z = t[y].f;
 47         int k = t[z].s[1] == y; 
 48         if (z == f) rotate(x);
 49         else 
 50             if (t[y].s[k] == x)
 51                 rotate(y),rotate(x);
 52             else rotate(x),rotate(x);
 53     }
 54 }
 55 
 56 int find(int x,int k) // 查询第k大【字符串的第k位】
 57 {
 58     if (lch.size == k-1) return x;
 59     if (lch.size > k-1) return find(t[x].s[0],k);
 60     else return find(t[x].s[1],k-1-lch.size);
 61 }
 62 
 63 void build(int l,int r,int x)
 64 {
 65     int mid = (l + r) >> 1;
 66     t[x].ch = S[mid]-'a'+1; t[x].size = 1;
 67     t[x].key = mid;
 68 
 69     if (mid > l) build(l,mid-1,t[x].s[0] = ++cnt);
 70     if (mid < r) build(mid+1,r,t[x].s[1] = ++cnt);
 71 
 72     update(x);
 73 }
 74 
 75 void workR()
 76 {
 77     scanf("%d %c\n",&k,&CC);
 78     int x = find(root,k+1);
 79     t[x].ch = CC-'a'+1;
 80     while (x != 0)
 81         update(x),x = t[x].f;
 82 }
 83 
 84 void workI()
 85 {
 86     scanf("%d %c\n",&k,&CC);
 87     int y = find(root,k+1);
 88     splay(y,0);
 89     int x = find(root,k+2);
 90     splay(x,y);
 91 
 92     t[x].s[0] = ++cnt; t[cnt].ch = CC-'a'+1; t[cnt].f = x; x = cnt;
 93     while (x != 0)
 94         update(x),x = t[x].f;
 95     splay(cnt,0);
 96 }
 97 
 98 uni get_num(int sta,int len)
 99 {
100     int y = find(root,sta);
101     splay(y,0);
102     int x = find(root,sta+len+1);
103     splay(x,y);
104 
105     return lch.num;
106 }
107 
108 void workQ()
109 {
110     int a,b;
111     scanf("%d %d\n",&a,&b);
112     if (a > b) swap(a,b);
113     int mid,l = 0,r = cnt-b;
114     while (l < r)
115     {
116         mid = (l + r) >> 1;
117         uni t1 = get_num(a,mid), t2 = get_num(b,mid);
118         if (t1 == t2) l = mid+1;
119         else r = mid;
120     }
121     printf("%d\n",l-1);
122 }
123 
124 int main()
125 {
126     _27[0] = 1;
127     for (i = 1; i <= 100000; i++) _27[i] = 27*_27[i-1];
128 
129     scanf("%s\n",S+1); LEN = strlen(S+1); cnt = 1;
130     build(1,LEN,1); root = 1;
131 
132     int x = root; while (t[x].s[0] != 0) t[x].size++,x = t[x].s[0]; 
133     t[x].size++; t[x].s[0] = ++cnt; t[cnt].size = 1; t[cnt].f = x;
134 
135     x = root; while (t[x].s[1] != 0) t[x].size++,x = t[x].s[1]; 
136     t[x].size++; t[x].s[1] = ++cnt; t[cnt].size = 1; t[cnt].f = x;
137 
138     scanf("%d\n",&m);
139     while (m--)
140     {
141         scanf("%c",&ch);
142         if (ch == 'Q') workQ();
143         if (ch == 'R') workR();
144         if (ch == 'I') workI();
145     }
146 
147     return 0;
148 }
View Code

 

posted on 2015-01-13 15:44  MMMoonLighttt  阅读(145)  评论(1编辑  收藏  举报

导航