题解 P4008 [NOI2003] 文本编辑器
题目描述
Link
您需要实现一个文本编辑器,需要支持一下六个操作:
- \(Move(k)\) ,将光标移到第 \(k\) 个字符之后。若 \(k=0\) ,则表示把光标移到文本开头。
- \(Insert(len ,s)\) ,在当前光标后插入一个长度为 \(len\) 的字符串 \(s\)。
- \(Delete(n)\) ,删除光标后的 \(n\) 个字符。
- \(Get(n)\) ,输出光标后的 \(n\) 个字符。
- \(Prev()\) ,将光标前移一个字符。
- \(Next()\) ,将光标后移一个字符。
在 \(Insert\) 操作中可能存在换行符,你需要忽略掉它们,但是保证所有的字符的 ASCII 码在 \([32 ,126]\) 内。
保证 \(Insert\) 操作插入的总长度不超过 \(2^{21}\) 个。
保证光标不会移到非法位置,保证删除,查询的字符存在。
Solution
发现有区间插入,区间删除,考虑用 \(Splay\) 维护。
用一个变量维护当前的光标位置,设这个变量叫做 \(pos\) 。
首先,为了防止越界,我们现在 \(Splay\) 里插入两个 \(\backslash0\) ,防止越界。
之所以选择 \(\backslash 0\) ,是因为把这个字符输出来跟没输出一样,对于 \(GET\) 操作就不需要特判了。
一个一个操作看:
- \(Build\)
没有这个操作!
我们先看看怎么把一个序列 \(a\) 建成一棵 \(Splay\) 。
类似线段树的建树方法,设当前节点在序列中所对应的位置是 \([l,r]\) ,设 \(mid = (l+r)/2\) 。
我们让这个节点的值为 $a[mid] $ ,然后左儿子对应的位置是 \([l,mid - 1]\) ,右儿子对应的位置就是 \([mid+1,r]\) 。
递归建树即可。
这样建出来的树是平衡,高度为 \(\mathcal{O}(\log_2 n)\) 。
- \(Move\)
直接让 \(pos=k\) 即可。
- \(Insert\)
我们在 \(Splay\) 中找到排名为 \(pos+1\) 的节点(前面插入了一个空节点,所以要 \(+1\) ),设这个节点为 \(ls\) 。
同时找到排名为 \(pos+2\) 的节点,设这个节点为 \(rs\) 。
把 \(ls\) 旋转到根,把 \(rs\) 旋转到 \(ls\) 下面,然后找到 \(rs\) 的左儿子,在这个左儿子下把这个插入的串用前面的 \(build\) 弄出来就行。
记得 \(pushup\) 。
- \(Delete\)
在 \(Splay\) 中找到排名为 \(pos+1\) 的节点(要删除的区间开头是 \(pos+1\) ,因为 \(Splay\) 要找的是前一个节点,但是又加上了一个空节点,所以位置还是 \(pos+1\) ),记为 \(ls\) 。
再找到排名为 \(pos+len+2\) 的节点(要删除的区间结尾是 \(pos+len\) ,前面有一个空节点,然后又要找后一个节点,所以是 \(pos+len+2\)) ,记为 \(rs\) 。
把 \(ls\) 旋转到根,把 \(rs\) 旋转到 \(ls\) 下面,直接删除 \(rs\) 的左儿子即可。
- \(Get\)
和上面大同小异,只不过不是删除 \(rs\) 的左儿子,而是输出 \(rs\) 左儿子的中序遍历。
- \(Prev\)
让 \(pos \to pos - 1\) 。
- \(Next\)
让 \(pos\to pos + 1\) 。
思路倒是挺简单,但是代码有一点点难写。
第一次一遍过。
#include <cstdio>
#include <cstring>
#include <cctype>
inline int read() {
int num = 0 ,f = 1; char c = getchar();
while (!isdigit(c)) f = c == '-' ? -1 : f ,c = getchar();
while (isdigit(c)) num = (num << 1) + (num << 3) + (c ^ 48) ,c = getchar();
return num * f;
}
inline int min(int a ,int b) {return a < b ? a : b;}
inline int max(int a ,int b) {return a > b ? a : b;}
inline void swap(int &a ,int &b) {int t = a; a = b; b = t;}
const int N = 1 << 22;
struct Splay {
struct node {
int ch[2] ,fa ,son ,size;
char val;
node () : fa(0) ,son(0) ,size(0) ,val(0) {}
}t[N]; int root ,tot;
Splay () : root(0) ,tot(0) {}
inline void update(int now) {
t[now].size = t[t[now].ch[0]].size + t[t[now].ch[1]].size + 1;
}
inline void rotate(int x) {
int y = t[x].fa ,z = t[y].fa ,k = t[y].ch[1] == x;
t[x].fa = z; t[z].ch[t[z].ch[1] == y] = x;
t[y].ch[k] = t[x].ch[k ^ 1]; t[t[x].ch[k ^ 1]].fa = y;
t[x].ch[k ^ 1] = y; t[y].fa = x;
update(y); update(x);
}
inline void splay(int x ,int goal) {
while (t[x].fa != goal) {
int y = t[x].fa ,z = t[y].fa;
if (z != goal) {
if ((t[z].ch[1] == y) == (t[y].ch[1] == x)) rotate(y);
else rotate(x);
}
rotate(x);
}
if (goal == 0) root = x;
}
inline int findval(int rank) {
int now = root;
while (now) {
int l = t[now].ch[0];
if (t[l].size >= rank) now = l;
else if (t[l].size + 1 >= rank) return now;
else rank -= t[l].size + 1 ,now = t[now].ch[1];
}
return 0;
}
inline void build(int &now ,int fa ,int l ,int r ,char *s) {
if (l > r) return ;
int mid = (l + r) >> 1;
now = ++tot;
t[now].val = s[mid];
t[now].fa = fa;
build(t[now].ch[0] ,now ,l ,mid - 1 ,s);
build(t[now].ch[1] ,now ,mid + 1 ,r ,s);
update(now);
}
inline void insert(int st ,int len ,char *s) {
int l = findval(st + 1) ,r = findval(st + 2);
splay(l ,0); splay(r ,l);
build(t[r].ch[0] ,r ,1 ,len ,s); update(r); update(l);
}
inline void remove(int a ,int b) {
int l = findval(a) ,r = findval(b + 2);
splay(l ,0); splay(r ,l);
int now = t[r].ch[0];
t[now].fa = t[r].ch[0] = 0; update(r); update(l);
}
inline void dfs(int now) {
if (now == 0) return ;
if (t[now].ch[0]) dfs(t[now].ch[0]);
putchar(t[now].val);
if (t[now].ch[1]) dfs(t[now].ch[1]);
}
inline void print(int a ,int b) {
int l = findval(a) ,r = findval(b + 2);
splay(l ,0); splay(r ,l);
int now = t[r].ch[0];
dfs(now); puts(""); //记得输出一个换行
}
}t;
char s[N] ,opt[10];
int n ,pos;
signed main() {
n = read();
s[0] = 0; s[1] = 0;
t.build(t.root ,0 ,1 ,2 ,s);
while (n--) {
scanf("%s" ,opt);
if (opt[0] == 'M') pos = read();
else if (opt[0] == 'I') {
int len = read();
//这里输入有点麻烦,需要把不在 [32,126] 之间的字符过滤掉
for (int i = 1; i <= len; i++) {
char c = getchar();
while (c < 32 || c > 126) c = getchar();
s[i] = c;
}
t.insert(pos ,len ,s);
}
else if (opt[0] == 'D') {
int len = read(); t.remove(pos + 1 ,pos + len);
}
else if (opt[0] =='G') {
int len = read(); t.print(pos + 1 ,pos + len);
}
else if (opt[0] == 'P') pos--;
else if (opt[0] == 'N') pos++;
}
return 0;
}