ODT 理论+代码
简介
ODT 用于维护一段颜色序列,支持 \(O(\log n)\) 实现区间覆盖颜色,查询单点颜色,\(O(1)\) 查询单点 pre
(左侧第一个相同颜色的位置)等。
ODT 将一段相同的颜色段存储为 set
里的一个节点。区间覆盖 \([l,r]\) 时,暴力处理并删除 \([l,r]\) 包含的所有颜色段,再插入 \([l,r]\)。若 \(l\) 或 \(r\) 落在原先某段区间的中间,就将这段区间按 \(l\) 或 \(r\) 劈成区间 \([l,r]\) 内的部分和 \([l,r]\) 外的部分,再处理即可。
时间复杂度分析
对于 ODT,我们有颜色段均摊的经典结论:
对一个长度为 \(n\) 的颜色序列进行 \(m\) 次区间赋值,向 set
中插入和删除节点的次数是均摊 \(O(n+m)\) 的。
这是因为,每次执行区间覆盖时,最多会新产生 \(3\) 个区间(\([l,r]\)、左右各劈出一个区间),而每个区间最多被删除一次。因此总操作次数是 \(O(n+m)\) 的。算上 set
,时间复杂度为 \(O\big((n+m)\log n\big)\)。
代码实现
核心:split
和 assign
函数。
split
set<Node>::iterator split(cint p);
操作:将 ODT 中 \(p\) 所在的颜色段 \([l,r]\) 分割成 \([l,p-1]\) 和 \([p,r]\);
参数:分割点 \(p\);
返回值:颜色段 \([p,r]\) 的迭代器;
实现:
set<Node>::iterator split(cint x) {
set<Node>::iterator it = --tr.upper_bound({x, 0, 0});
if(it->l == x) return it;
int l = it->l, r = it->r, c = it->c;
tr.erase(it);
tr.insert({l, x - 1, c});
tr.insert({x, r, c});
return tr.find({x, r, c});
}
assign
void assign(cint l, cint r, cint c);
操作:将 ODT 中 \([l,r]\) 覆盖为颜色 \(c\);
参数:\(l,r,c\);
实现:
void assign(cint l, cint r, cint c) {
set<Node>::iterator itr = split(r + 1), itl = split(l);
for(set<Node>::iterator it = itl; it != itr; ++it) {
// 对即将删掉的区间进行操作
}
tr.erase(itl, itr);
tr.insert({l, r, c});
}
注意点:要先 split(r + 1)
然后 split(l)
。如果反过来,可能在执行 \(split(r + 1)\) 时导致 itl
迭代器的失效。
带 pre 版本代码
namespace odt {
set<Node> pos[2 * N], tr;
int get_pre(cint x) {
if(x > n || x < 1) return -1;
set<Node>::iterator it = --tr.upper_bound({x, 0, 0});
if(x > it->l) return x - 1;
it = --pos[it->c].find(*it);
return it->r;
}
int get_suc(cint x) {
if(x > n || x < 1) return -1;
set<Node>::iterator it = --tr.upper_bound({x, 0, 0});
if(x < it->r) return x + 1;
it = ++pos[it->c].find(*it);
return it->l;
}
void build() {
for(int i = 1; i <= nn; i++) pos[i].insert({0, 0, i});
for(int i = 1; i <= nn; i++) pos[i].insert({n + 1, n + 1, i});
tr.insert({0, 0, 0});
tr.insert({n + 1, n + 1, 0});
for(int i = 2, j = 1; i <= n + 1; i++) {
if(a[i] != a[j]) {
tr.insert({j, i - 1, a[i - 1]});
pos[a[i - 1]].insert({j, i - 1, a[i - 1]});
j = i;
}
}
}
set<Node>::iterator split(cint x) {
set<Node>::iterator it = --tr.upper_bound({x, 0, 0});
if(it->l == x) return it;
int l = it->l, r = it->r, c = it->c;
pos[c].erase(*it);
tr.erase(it);
pos[c].insert({l, x - 1, c});
pos[c].insert({x, r, c});
tr.insert({l, x - 1, c});
tr.insert({x, r, c});
return tr.find({x, r, c});
}
void assign(cint l, cint r, cint c) {
set<Node>::iterator itr = split(r + 1), itl = split(l);
for(set<Node>::iterator it = itl; ++it != itr; ) {
add_modify(it->l, it->l - 1);
}
static vector<int> suc;
suc.clear();
for(set<Node>::iterator it = itl; it != itr; ++it) {
int tmp = get_suc(it->r);
if(tmp > r) suc.push_back(tmp);
}
for(set<Node>::iterator it = itl; it != itr; ++it) {
pos[it->c].erase(*it);
}
pos[c].insert({l, r, c});
tr.erase(itl, itr);
tr.insert({l, r, c});
for(int tmp : suc) {
add_modify(tmp, get_pre(tmp));
}
add_modify(l, get_pre(l));
add_modify(get_suc(r), r);
}
}