数据结构-Old Driver Tree(Chtholly Tree)
学习笔记-珂朵莉树
ODT是不知何时火起来的一种暴力数据结构,这玩意是一种基于stl的set的一种奇怪的东西,叫它珂朵莉树是因为它的暴力美学,也因为它的出处 CF896C
话不多说,进入正题,这玩意我是看着SJR大佬的博客,在WZY和MXY大佬的答疑下弄会的。因为我是在太弱了,学这玩意之前连重载运算符和构造函数都不会用,并且set也没用过。。。
存储原理
首先定义一种结构体,
1 struct node 2 { 3 int l,r; 4 mutable LL v;//mutable以便随时在set里改变它的值 5 node(int L, int R=-1, LL V=0):l(L), r(R), v(V) {}//构造函数,便于快速添加结构体 6 bool operator<(const node& o) const//重载运算符,以左端为标准排列在set里面 7 { 8 return l < o.l; 9 } 10 };
这样的一个结构体表示在数列的[l, r]之间全都是v,然后用set存储,这样可以有顺序的一段一段存储序列每一段的数值。
1 set <node> s;
这样子就ok了
操作(一):分裂split
这个操作可以把一段区间咔吧一下变成两个区间,值不变,相当于对整个数列什么也没做,但是会很方便后面的操作。
1 #define IT set<node>::iterator 2 IT split(int pos)//这个操作会把区间[l, r]从pos分开成[l, pos - 1][pos, r]两部分 3 { 4 IT it = s.lower_bound(node(pos));//因为区间按照左端排序,所以可以利用lowerbound快速找到左端点大于等于pos的第一个区间 5 if (it != s.end() && it->l == pos) return it;//如果左端点正好是pos那就不用split了 6 --it;//如果是大于就操作前一个区间 7 int L = it->l, R = it->r; 8 LL V = it->v;//存储一下区间的两端和值 9 s.erase(it);//扔掉这个区间 10 s.insert(node(L, pos-1, V));//添加左区间 11 return s.insert(node(pos, R, V)).first;//添加右区间并返回它的迭代器,insert函数的first是所加入的区间的迭代器 12 }
这个函数大体思路:找到区间,特判,记录值,扔掉区间,加入区间,返回迭代器,注意pos在右区间内(划重点)
操作(二):区间赋值ass♂ign
有了split函数,我们就可以把需要区间赋值两端点切开,把中间的区间全部扔掉,然后加入一个新区间,由于这个操作会减少set中的元素个数,所以就可以保证复杂度是对的,代码如下
1 void assign(int l, int r, LL val=0) 2 { 3 IT itl = split(l),itr = split(r+1);//记录左右端 4 s.erase(itl, itr);//扔掉[l, r)的所有东西 5 s.insert(node(l, r, val));//把新的搞进来 6 }
其他操作:直接按照存储原理暴力
区间加:两端切开,切下来的所有区间暴力加
1 void add(int l, int r, LL val=1) 2 { 3 IT itl = split(l),itr = split(r+1); 4 for (; itl != itr; ++itl) itl->v += val; 5 }
区间查询第k小:切出来暴力排序
1 LL rank(int l, int r, int k) 2 { 3 vector<pair<LL, int> > vp; 4 IT itl = split(l),itr = split(r+1); 5 vp.clear(); 6 for (; itl != itr; ++itl) 7 vp.push_back(pair<LL,int>(itl->v, itl->r - itl->l + 1)); 8 sort(vp.begin(), vp.end()); 9 for (vector<pair<LL,int> >::iterator it=vp.begin();it!=vp.end();++it) 10 { 11 k -= it->second; 12 if (k <= 0) return it->first; 13 } 14 return -1LL; 15 }
区间幂和:快速幂,乘以段长,期间不要忘记膜,不然就naive了
1 LL sum(int l, int r, int ex, int mod) 2 { 3 IT itl = split(l),itr = split(r+1); 4 LL res = 0; 5 for (; itl != itr; ++itl) 6 res = (res + (LL)(itl->r - itl->l + 1) * pow(itl->v, LL(ex), LL(mod))) % mod; 7 return res; 8 }
代码实现
由于我太弱,总写炸,就直接把标程弄进来了
1 #include<cstdio> 2 #include<set> 3 #include<vector> 4 #include<utility> 5 #include<algorithm> 6 #define IT set<node>::iterator 7 8 using std::set; 9 using std::vector; 10 using std::pair; 11 12 typedef long long LL; 13 const int MOD7 = 1e9 + 7; 14 const int MOD9 = 1e9 + 9; 15 const int imax_n = 1e5 + 7; 16 17 LL pow(LL a, LL b, LL mod) 18 { 19 LL res = 1; 20 LL ans = a % mod; 21 while (b) 22 { 23 if (b&1) res = res * ans % mod; 24 ans = ans * ans % mod; 25 b>>=1; 26 } 27 return res; 28 } 29 30 struct node 31 { 32 int l,r; 33 mutable LL v; 34 node(int L, int R=-1, LL V=0):l(L), r(R), v(V) {} 35 bool operator<(const node& o) const 36 { 37 return l < o.l; 38 } 39 }; 40 41 set<node> s; 42 43 IT split(int pos) 44 { 45 IT it = s.lower_bound(node(pos)); 46 if (it != s.end() && it->l == pos) return it; 47 --it; 48 int L = it->l, R = it->r; 49 LL V = it->v; 50 s.erase(it); 51 s.insert(node(L, pos-1, V)); 52 return s.insert(node(pos, R, V)).first; 53 } 54 55 void add(int l, int r, LL val=1) 56 { 57 IT itl = split(l),itr = split(r+1); 58 for (; itl != itr; ++itl) itl->v += val; 59 } 60 61 void assign_val(int l, int r, LL val=0) 62 { 63 IT itl = split(l),itr = split(r+1); 64 s.erase(itl, itr); 65 s.insert(node(l, r, val)); 66 } 67 68 LL rank(int l, int r, int k) 69 { 70 vector<pair<LL, int> > vp; 71 IT itl = split(l),itr = split(r+1); 72 vp.clear(); 73 for (; itl != itr; ++itl) 74 vp.push_back(pair<LL,int>(itl->v, itl->r - itl->l + 1)); 75 std::sort(vp.begin(), vp.end()); 76 for (vector<pair<LL,int> >::iterator it=vp.begin();it!=vp.end();++it) 77 { 78 k -= it->second; 79 if (k <= 0) return it->first; 80 } 81 return -1LL; 82 } 83 84 LL sum(int l, int r, int ex, int mod) 85 { 86 IT itl = split(l),itr = split(r+1); 87 LL res = 0; 88 for (; itl != itr; ++itl) 89 res = (res + (LL)(itl->r - itl->l + 1) * pow(itl->v, LL(ex), LL(mod))) % mod; 90 return res; 91 } 92 93 int n, m; 94 LL seed, vmax; 95 96 LL rnd() 97 { 98 LL ret = seed; 99 seed = (seed * 7 + 13) % MOD7; 100 return ret; 101 } 102 103 LL a[imax_n]; 104 105 int main() 106 { 107 scanf("%d %d %lld %lld",&n,&m,&seed,&vmax); 108 for (int i=1; i<=n; ++i) 109 { 110 a[i] = (rnd() % vmax) + 1; 111 s.insert(node(i,i,a[i])); 112 } 113 s.insert(node(n+1, n+1, 0)); 114 int lines = 0; 115 for (int i =1; i <= m; ++i) 116 { 117 int op = int(rnd() % 4) + 1; 118 int l = int(rnd() % n) + 1; 119 int r = int(rnd() % n) + 1; 120 if (l > r) 121 std::swap(l,r); 122 int x, y; 123 if (op == 3) 124 x = int(rnd() % (r-l+1)) + 1; 125 else 126 x = int(rnd() % vmax) +1; 127 if (op == 4) 128 y = int(rnd() % vmax) + 1; 129 if (op == 1) 130 add(l, r, LL(x)); 131 else if (op == 2) 132 assign_val(l, r, LL(x)); 133 else if (op == 3) 134 printf("%lld\n",rank(l, r, x)); 135 else 136 printf("%lld\n",sum(l, r, x, y)); 137 } 138 return 0; 139 }