UVa 1232 / LA 4108 线段树
这题就是个裸的线段树。。但细节容易想错。。
题意:一个全0的序列,m次操作,每次给出一个区间[l,r)和一个值v,将该区间内所有小于等于v的数全部修改为v。求总的修改次数。
怎么做呢?一开始我是这么做的:开一个线段树,每个节点维护一个值:该区间内的元素的值——如果该区间内元素值不同,则置为-1,然后直接统计。(而且一开始我居然把1-4*maxn内的所有点都初始化了一遍。。)这样做最坏情况下每次都要O(4*n),当然华丽丽地T。
怎么优化呢?其实很好想:如果当前区间内的最小值都比v大,那么就无需继续了;如果当前区间内的最大值都比v小,那就直接把整个区间改成v好了;如果v介于maxv和minv之间,那么继续往下找。这样一来,连这个区间内的元素值是否相同这个标记都省去了——因为如果相同,则有maxv=minv,v要么比这个值大,要么比这个值小,不会再递归下去了。
Attention!说过了,这道题细节很容易想错:我们应该在哪些地方调用update、pushdown和maintain?注意这道题的maintain不是通常意义上的maintain——后者是用父亲维护儿子,而前者是用儿子维护父亲。所以,这里的maintain用pushup更为恰当。所以,update伪代码如下:
注意,对于没有递归访问的子树,没有调用push_up,否则将造成无中生有的节点。
// LA4108 #include <cstdio> #include <cstring> #include <algorithm> using namespace std; const int N=120000, INF=0x6f6f6f6f; #define rep(i,a,b) for (int i=a; i<=b; i++) int a[N], b[N], t[N], minv[4*N], maxv[4*N], y1, y2, h, len; void read(int &r) { int x=0, f=1; char ch=getchar(); while (ch<'0' || ch>'9') { if(ch=='-') f=-1; ch=getchar(); } while (ch>='0' && ch<='9') { x=x*10+ch-'0'; ch=getchar(); } r=x*f; } void build(int o, int L, int R) { minv[o]=maxv[o]=0; if (L==R) return; int M=(L+R)>>1, lc=o<<1, rc=lc|1; if (L<=M) build(lc, L, M); if (M<R) build(rc, M+1, R); } void maintain(int o, int L, int R) { if (L>=R) return; int lc=o<<1, rc=lc|1; maxv[o]=max(maxv[lc], maxv[rc]); minv[o]=min(minv[lc], minv[rc]); } void pushdown(int o) { if (maxv[o]!=minv[o]) return; int lc=o<<1, rc=lc|1; maxv[lc]=minv[lc]=maxv[rc]=minv[rc]=maxv[o]; } void update(int o, int L, int R) { int lc=o<<1, rc=lc|1, ok=(y1<=L && R<=y2); if (h<minv[o]) return; if (ok && maxv[o]<=h) { maxv[o]=minv[o]=h; len+=(R-L+1); return; } if (L<R) pushdown(o); int M=(L+R)>>1; if (y1<=M) update(lc, L, M); if (M<y2) update(rc, M+1, R); maintain(o, L, R); } int T, ans, n, minl, maxr; int main() { read(T); while (T--) { read(n); minl=INF, maxr=-INF, ans=0; rep(i,1,n) { read(a[i]), read(b[i]), read(t[i]); b[i]--; minl=min(minl, a[i]); maxr=max(maxr, b[i]); } build(1, minl, maxr); maxv[1]=minv[1]=0; rep(i,1,n) { y1=a[i], y2=b[i], h=t[i], len=0; update(1, minl, maxr); ans+=len; } printf("%d\n", ans); } return 0; }