神器! 线性变换线段树 —— 线段树入门及应用
在写这篇博客之前首先申明一下,这个线段树的名字是我的好队友 shu_mj 取的。由于实在很好用,对于新手容易上手所以写一篇博客造福老百姓。
然后说明一下这仅仅是一篇关于线段树入门的文章。
单点更新:
单点更新不需要标记整个线段树只需要自底向上的维护值(最小值,最大值,和等)就行了。实现难度较低,新手也能很快学会。
在线段树中每个节点的含义都是一个线段根节点表示的是1-n的线段,假设根节点标号是o,他的左子树标号为o<<1,右子树标号为o<<1|1相应的左子树表示的区间为1-m(m=(l+r)>>1),而右子树表示的区间为m+1-r
这样我们用一颗高度为logn 的树即可表示整条线段了。
对于单点更新的线段树每个标记都能下放到最底层,由于每次下放标记最多沿着一路走所以每次操作的复杂度为logn的。
对于求和的时候我们把区间不断的拆分用线段树上的小区间(节点表示的区间)来拼成需要求和的完整区间,最后将这些节点上的标记累加起来即为答案。复杂度同样为logn的。
代码如下:
1 int num[LEN], sum[4*LEN], n; 2 3 void pushup(int o){ 4 sum[o] = sum[o<<1] + sum[o<<1|1]; 5 } 6 7 void build(int l, int r, int o){ 8 if(l == r){ 9 //设初值 10 return ; 11 } 12 int m = (l + r)/2; 13 build(l, m, o<<1); 14 build(m+1, r, o<<1|1); 15 pushup(o); 16 } 17 18 void update(int l, int r, int o, int pos, int val){ 19 if(l == r) { 20 sum[o] += val; 21 return ; 22 } 23 int m = (l + r)/2; 24 if(pos <= m) update(l, m, o<<1, pos, val); 25 else update(m+1, r, o<<1|1, pos, val); 26 pushup(o); 27 } 28 29 int query(int l, int r, int o, int L, int R){ 30 if(l >= L && r <= R) return sum[o]; 31 int m = (l + r) / 2, ret = 0; 32 if(L <= m) ret += query(l, m, o<<1, L, R); 33 if(R > m) ret += query(m+1, r, o<<1|1, L, R); 34 return ret; 35 }
区间更新:
区间更新相比较单点更新来说就要难很多了。首先我们要设计一个懒惰标记,为什么要设计这个标记呢?
我们若是在每次更新线段树的时候都更新至最底层那么复杂度显然会退化成线性的,因此我们给每个节点上都添加了一个懒惰标记若是不访问以这个节点为根的子树,这个标记就停留在这个节点。等到需要访问以这个节点为根的子树的时候在把标记pushdown到他的左右孩子上。这样复杂度就维持了nlogn了。在统计的时候和单点更新类似,只是在访问孩子节点之前需要把懒惰标记下放。
线性变换线段树:
所谓线性变换,就是我们的线段树能处理ax+b这种操作。现在线段树的常用操作有add,mul,set无论是哪种我们都可以使用线性变换得到。
add v: 1*x+v mul v:v*x+0 set v: 0*x+v
怎么样很方便吧。
这样一来我们就可以将这份线段树存为模板了,足以应付一般的线段树题目了。线性变换线段树由于维护了多余的信息多以时间效率可能不那么令人满意。
代码如下:
1 struct SegmentTree { 2 ll _sum[4*LEN], _mul[4*LEN], _add[4*LEN]; 3 SegmentTree(){} 4 5 void Display(int l, int r, int o, int d){ 6 if(l == r){ 7 for(int i=0; i<d; i++) cout << "\t"; 8 cout << _sum[o] << "(" << l << ", " << r << ")" << endl; 9 return ; 10 } 11 int m = (l + r) / 2; 12 Display(l, m, o<<1, d+1); 13 for(int i=0; i<d; i++) cout << "\t"; 14 cout << _sum[o] << "(" << l << ", " << r << ")" << endl; 15 Display(m+1, r, o<<1|1, d+1); 16 } 17 18 void init(){ 19 for(int i=0; i<4*LEN; i++){ 20 _mul[i] = 1; _add[i] = 0; 21 } 22 } 23 24 void pushup(int o){ 25 _sum[o] = _sum[o<<1] + _sum[o<<1|1]; 26 } 27 28 void push(int o, int l, int r, ll adval, ll muval){ 29 _sum[o] *= muval; 30 _mul[o] *= muval; 31 _add[o] *= muval; 32 33 _sum[o] += (r-l+1) * adval; 34 _add[o] += adval; 35 } 36 37 void pushdown(int o, int l, int r){ 38 int m = (l + r) / 2; 39 push(o << 1, l, m, _add[o], _mul[o]); 40 push(o << 1 | 1, m+1, r, _add[o], _mul[o]); 41 _add[o] = 0; _mul[o] = 1; 42 } 43 44 void build(int l, int r, int o){ 45 if(l == r){ 46 scanf("%lld", &_sum[o]); 47 return ; 48 } 49 int m = (l + r) / 2; 50 build(l, m, o<<1); 51 build(m+1, r, o<<1|1); 52 pushup(o); 53 } 54 55 void update(int l, int r, int o, int L, int R, ll adval, ll muval){ 56 if(L <= l && R >= r){ 57 push(o, l, r, adval, muval); 58 return ; 59 } 60 pushdown(o, l, r); 61 int m = (l + r) / 2; 62 if(L <= m) update(l, m, o<<1, L, R, adval, muval); 63 if(R > m) update(m+1, r, o<<1|1, L, R, adval, muval); 64 pushup(o); 65 } 66 67 ll query(int l, int r, int o, int L, int R){ 68 if(L <= l && R >= r){ 69 return _sum[o]; 70 } 71 pushdown(o, l, r); 72 int m = (l + r) / 2; 73 ll ret = 0; 74 if(L <= m) ret += query(l, m, o<<1, L, R); 75 if(R > m) ret += query(m+1, r, o<<1|1, L, R); 76 return ret; 77 } 78 };
最后,我用了这一份模板ac了一道2013杭州邀请赛的纯线段树的题目在以前不用这份模板的时候我写了5600K的代码再ac,而使用了模板之后瞬间代码量减小2K+
代码如下:
1 #include <cstdio> 2 #include <algorithm> 3 #define MP(a, b) make_pair(a, b) 4 using namespace std; 5 typedef pair<int, int> pii; 6 typedef pair<int, pii> piii; 7 const int MOD = 10007; 8 const int LEN = 100000+10; 9 10 void getsum(piii &a, piii b){ 11 a.first += b.first; a.first %= MOD; 12 a.second.first += b.second.first; a.second.first %= MOD; 13 a.second.second += b.second.second; a.second.second %= MOD; 14 } 15 16 struct SegmentTree { 17 int _sum[4*LEN], _mul[4*LEN], _add[4*LEN], _sum2[4*LEN], _sum3[4*LEN]; 18 19 void init(){ 20 for(int i=0; i<4*LEN; i++){ 21 _mul[i] = 1; _add[i] = 0; 22 _sum[i] = _sum2[i] = _sum3[i] = 0; 23 } 24 } 25 26 void pushup(int o){ 27 _sum[o] = (_sum[o<<1] + _sum[o<<1|1]) % MOD; 28 _sum2[o] = (_sum2[o<<1] + _sum2[o<<1|1]) % MOD; 29 _sum3[o] = (_sum3[o<<1] + _sum3[o<<1|1]) % MOD; 30 } 31 32 void push(int o, int l, int r, int adval, int muval){ 33 _sum3[o] *= (muval * muval % MOD * muval % MOD); 34 _sum2[o] *= (muval * muval % MOD); 35 _sum[o] *= muval; 36 _sum[o] %= MOD; _sum2[o] %= MOD; _sum3[o] %= MOD; 37 _mul[o] *= muval; 38 _add[o] *= muval; 39 _mul[o] %= MOD; _add[o] %= MOD; 40 41 _sum3[o] += (3*adval%MOD*_sum2[o]%MOD + 3*adval%MOD*adval%MOD*_sum[o]%MOD); 42 _sum3[o] += ((r-l+1)*adval%MOD*adval%MOD*adval%MOD); 43 _sum2[o] += (2*adval%MOD*_sum[o]%MOD + (r-l+1)*adval%MOD*adval%MOD); 44 _sum[o] += (r-l+1) * adval; 45 _sum3[o] %= MOD; _sum2[o] %= MOD; _sum[o] %= MOD; 46 _add[o] += adval; _add[o] %= MOD; 47 } 48 49 void pushdown(int o, int l, int r){ 50 int m = (l + r) / 2; 51 push(o << 1, l, m, _add[o], _mul[o]); 52 push(o << 1 | 1, m+1, r, _add[o], _mul[o]); 53 _add[o] = 0; _mul[o] = 1; 54 } 55 56 void build(int l, int r, int o){ 57 if(l == r){ 58 _sum[o] = _sum2[o] = _sum3[o] = 0; 59 return ; 60 } 61 int m = (l + r) / 2; 62 build(l, m, o<<1); 63 build(m+1, r, o<<1|1); 64 pushup(o); 65 } 66 67 void update(int l, int r, int o, int L, int R, int adval, int muval){ 68 if(L <= l && R >= r){ 69 push(o, l, r, adval, muval); 70 return ; 71 } 72 pushdown(o, l, r); 73 int m = (l + r) / 2; 74 if(L <= m) update(l, m, o<<1, L, R, adval, muval); 75 if(R > m) update(m+1, r, o<<1|1, L, R, adval, muval); 76 pushup(o); 77 } 78 79 piii query(int l, int r, int o, int L, int R){ 80 if(L <= l && R >= r){ 81 return MP(_sum[o], MP(_sum2[o], _sum3[o])); 82 } 83 pushdown(o, l, r); 84 int m = (l + r) / 2; 85 piii ret = MP(0, MP(0, 0)); 86 if(L <= m) getsum(ret, query(l, m, o<<1, L, R)); 87 if(R > m) getsum(ret, query(m+1, r, o<<1|1, L, R)); 88 return ret; 89 } 90 }; 91 92 SegmentTree sg; 93 94 int main() 95 { 96 int n, m, op, L, R, val; 97 while(scanf("%d%d", &n, &m) != EOF){ 98 if(!n && !m) break; 99 sg.init(); sg.build(1, n, 1); 100 for(int i=0; i<m; i++){ 101 scanf("%d%d%d%d", &op, &L, &R, &val); 102 if(op == 1){ 103 sg.update(1, n, 1, L, R, val, 1); 104 }else if(op == 2){ 105 sg.update(1, n, 1, L, R, 0, val); 106 }else if(op == 3){ 107 sg.update(1, n, 1, L, R, val, 0); 108 }else { 109 piii ans = sg.query(1, n, 1, L, R); 110 if(val == 1) printf("%d\n", ans.first); 111 else if(val == 2) printf("%d\n", ans.second.first); 112 else printf("%d\n", ans.second.second); 113 } 114 } 115 } 116 return 0; 117 }