线段树(四)——两个标记(add和set)
add无序,set有序。规定同时有两个标记时,表示先执行set再执行add。
1. 更新操作:
1 int op,cl,cr,v; 2 void update(int o, int L, int R) { 3 int lc = o*2, rc = o*2+1; 4 if(cl <= L && cr >= R) { // 标记修改 5 if(op == 2) addv[o] += v; 6 else { setv[o] = v; addv[o] = 0; } 7 } else { 8 pushdown(o); 9 int M = L + (R-L)/2; 10 if(cl <= M) update(lc, L, M); else maintain(lc, L, M); 11 if(cr > M) update(rc, M+1, R); else maintain(rc, M+1, R); 12 } 13 maintain(o, L, R); 14 }
此操作中需要维护标记,这里保证了不会出现先有add再有set,这种情况只会保留set。
值得注意的是,标记下推时左右子树都需要维护,其中递归进入的子树会在递归结束时自然调用maintain函数,而另一个子树需要手动调用maintain。
2. 标记传递:
1 void pushdown(int o) { 2 int lc = o*2, rc = o*2+1; 3 if(setv[o] >= 0) { 4 setv[lc] = setv[rc] = setv[o]; 5 addv[lc] = addv[rc] = 0; 6 setv[o] = -1; // 清除本结点标记 7 } 8 if(addv[o]) { 9 addv[lc] += addv[o]; 10 addv[rc] += addv[o]; 11 addv[o] = 0; // 清除本结点标记 12 } 13 }
set和add分别讨论,结点o有标记时才下推。
3. 维护信息:
1 void maintain(int o, int L, int R) { 2 int lc = o*2, rc = o*2+1; 3 if(R > L) { 4 sumv[o] = sumv[lc] + sumv[rc]; 5 minv[o] = min(minv[lc], minv[rc]); 6 maxv[o] = max(maxv[lc], maxv[rc]); 7 } 8 if(setv[o] >= 0) { minv[o] = maxv[o] = setv[o]; sumv[o] = setv[o] * (R-L+1); } 9 if(addv[o]) { minv[o] += addv[o]; maxv[o] += addv[o]; sumv[o] += addv[o] * (R-L+1); } 10 }
注意:所有叶子上总是保留set标记(初始化的)而不会被清除(pushdown只能针对非叶结点),因此maintain函数对于叶子来说并不会重复累加addv[o].
4. 查询信息:
1 int ql, qr; 2 void query(int o, int L, int R, int& ssum, int& smin, int &smax) { 3 int lc = o*2, rc = o*2+1; 4 maintain(o, L, R); // 处理被pushdown下来的标记 5 if(ql <= L && qr >= R) { 6 ssum = sumv[o]; 7 smin = minv[o]; 8 smax = maxv[o]; 9 } else { 10 pushdown(o); 11 int M = L + (R-L)/2; 12 int lsum = 0, lmin = INF, lmax = -INF; 13 int rsum = 0, rmin = INF, rmax = -INF; 14 if(ql <= M) query(lc, L, M, lsum, lmin, lmax); else maintain(lc, L, M); 15 if(qr > M) query(rc, M+1, R, rsum, rmin, rmax); else maintain(rc, M+1, R); 16 ssum = lsum + rsum; 17 smin = min(lmin, rmin); 18 smax = max(lmax, rmax); 19 } 20 }
也要维护信息和下推标记。
5. 初始化:
通过add,当然,也可以通过set.
1 memset(setv, 0, sizeof(setv)); //保证叶子结点的set标签 2 for(int i = 1; i <= n;i++) 3 { 4 scanf("%d", &v); 5 cl = cr = i; v = a[i]; op=2; 6 update(1, 1, n); 7 }
1 #include<bits/stdc++.h> 2 using namespace std; 3 4 const int INF = 0x3f3f3f3f; 5 const int maxn = 100000 + 10; 6 const int maxnode = maxn << 2; 7 int sumv[maxnode], minv[maxnode], maxv[maxnode], setv[maxnode], addv[maxnode]; 8 int n, a[maxn]; 9 10 // 维护信息 11 void maintain(int o, int L, int R) { 12 int lc = o*2, rc = o*2+1; 13 if(R > L) { 14 sumv[o] = sumv[lc] + sumv[rc]; 15 minv[o] = min(minv[lc], minv[rc]); 16 maxv[o] = max(maxv[lc], maxv[rc]); 17 } 18 if(setv[o] >= 0) { minv[o] = maxv[o] = setv[o]; sumv[o] = setv[o] * (R-L+1); } 19 if(addv[o]) { minv[o] += addv[o]; maxv[o] += addv[o]; sumv[o] += addv[o] * (R-L+1); } 20 } 21 22 // 标记传递 23 void pushdown(int o) { 24 int lc = o*2, rc = o*2+1; 25 if(setv[o] >= 0) { 26 setv[lc] = setv[rc] = setv[o]; 27 addv[lc] = addv[rc] = 0; 28 setv[o] = -1; // 清除本结点标记 29 } 30 if(addv[o]) { 31 addv[lc] += addv[o]; 32 addv[rc] += addv[o]; 33 addv[o] = 0; // 清除本结点标记 34 } 35 } 36 37 int op,cl,cr,v; 38 void update(int o, int L, int R) { 39 int lc = o*2, rc = o*2+1; 40 if(cl <= L && cr >= R) { // 标记修改 41 if(op == 2) addv[o] += v; 42 else { setv[o] = v; addv[o] = 0; } 43 } else { 44 pushdown(o); 45 int M = L + (R-L)/2; 46 if(cl <= M) update(lc, L, M); else maintain(lc, L, M); 47 if(cr > M) update(rc, M+1, R); else maintain(rc, M+1, R); 48 } 49 maintain(o, L, R); 50 } 51 52 int ql, qr; 53 void query(int o, int L, int R, int& ssum, int& smin, int &smax) { 54 int lc = o*2, rc = o*2+1; 55 maintain(o, L, R); // 处理被pushdown下来的标记 56 if(ql <= L && qr >= R) { 57 ssum = sumv[o]; 58 smin = minv[o]; 59 smax = maxv[o]; 60 } else { 61 pushdown(o); 62 int M = L + (R-L)/2; 63 int lsum = 0, lmin = INF, lmax = -INF; 64 int rsum = 0, rmin = INF, rmax = -INF; 65 if(ql <= M) query(lc, L, M, lsum, lmin, lmax); else maintain(lc, L, M); 66 if(qr > M) query(rc, M+1, R, rsum, rmin, rmax); else maintain(rc, M+1, R); 67 ssum = lsum + rsum; 68 smin = min(lmin, rmin); 69 smax = max(lmax, rmax); 70 } 71 } 72 73 void print_debug(int o, int L, int R) 74 { 75 printf("o:%d L:%d R:%d setv:%d addv:%d minv:%d\n", o, L, R, setv[o], addv[o], minv[o]); 76 if(R > L) 77 { 78 int M = L + (R - L) / 2; 79 print_debug(2*o, L, M); 80 print_debug(2*o+1, M+1, R); 81 } 82 } 83 84 int main() 85 { 86 memset(setv, 0, sizeof(setv)); 87 88 printf("数组元素个数:"); 89 scanf("%d", &n); 90 printf("数组元素:"); 91 for(int i = 1; i <= n;i++) 92 { 93 scanf("%d", &a[i]); 94 cl = cr = i; v = a[i]; op=2; 95 update(1, 1, n); 96 } 97 98 printf("1代表查询,2代表增加,3代表设置\n"); 99 printf("选择:"); 100 int chose; 101 while(scanf("%d", &chose) == 1) 102 { 103 if(chose == 1) 104 { 105 printf("查询的左右区间:"); 106 scanf("%d%d", &ql, &qr); 107 int ssum, smin, smax; 108 query(1, 1, n, ssum, smin, smax); 109 printf("最小值:%d\n", smin); 110 print_debug(1, 1, n); 111 } 112 else if(chose == 2) 113 { 114 printf("左右区间和增加值:"); 115 scanf("%d%d%d", &cl, &cr, &v); 116 op = 2; 117 update(1, 1, n); 118 print_debug(1, 1, n); 119 } 120 else 121 { 122 printf("左右区间和设置值:"); 123 scanf("%d%d%d", &cl, &cr, &v); 124 op = 3; 125 update(1, 1, n); 126 print_debug(1, 1, n); 127 } 128 printf("选择:"); 129 } 130 131 return 0; 132 }
个性签名:时间会解决一切