线段树合并
https://blog.csdn.net/keydou/article/details/83691189
bzoj 4756
题目大意:给出一棵树(根为 1),每个点有点权,对于每个点,询问它子树中点权比它大的点的个数
其实可以把这道题当成线段树合并入门题来做
首先把点权离散化,把所有的点都先单独建一颗权值线段树
然后从根开始 dfs ,一边 dfs 一边合并线段树,不断合并上去,最后求答案就行了
这种每个点都有权值、
#include<cstdio> #include<cstring> #include<algorithm> #define N 100005 using namespace std; int n,m,t,tot; int first[N],v[N],nxt[N]; int a[N],ans[N],val[N],root[N],size[N*40],lc[N*40],rc[N*40]; void add(int x,int y) { t++; nxt[t]=first[x]; first[x]=t; v[t]=y; } void build(int &root,int l,int r,int val) { root=++tot; size[root]=1; if(l==r) return; int mid=(l+r)>>1; if(val<=mid) build(lc[root],l,mid,val); else build(rc[root],mid+1,r,val); } int Merge(int x,int y) { if(!x) return y; if(!y) return x; size[x]+=size[y]; lc[x]=Merge(lc[x],lc[y]); rc[x]=Merge(rc[x],rc[y]); return x; } int calc(int root,int l,int r,int x,int y) { if(l>=x&&r<=y) return size[root]; int ans=0,mid=(l+r)>>1; if(x<=mid) ans+=calc(lc[root],l,mid,x,y); if(y>mid) ans+=calc(rc[root],mid+1,r,x,y); return ans; } void dfs(int x) { int i,j; for(i=first[x];i;i=nxt[i]) { j=v[i],dfs(j); root[x]=Merge(root[x],root[j]); } ans[x]=calc(root[x],1,m,val[x]+1,m); } int main() { int x,i; scanf("%d",&n); for(i=1;i<=n;++i) scanf("%d",&a[i]),val[i]=a[i]; for(i=2;i<=n;++i) scanf("%d",&x),add(x,i); sort(a+1,a+n+1),m=unique(a+1,a+n+1)-(a+1); for(i=1;i<=n;++i) { val[i]=lower_bound(a+1,a+m+1,val[i])-a; build(root[i],1,m,val[i]); } dfs(1); for(i=1;i<=n;++i) printf("%d\n",ans[i]); return 0; }
神奇的操作——线段树合并(例题: BZOJ2212)
什么是线段树合并?
首先你需要动态开点的线段树。(对每个节点维护左儿子、右儿子、存储的数据,然后要修改某儿子所在的区间中的数据的时候再创建该节点。)
考虑这样一个问题:
你现在有两棵权值线段树(大概是用来维护一个有很多数的可重集合那种线段树,若某节点对应区间是[l,r][l,r],则它存储的数据是集合中≥l≥l、≤r≤r的数的个数),现在你想把它们俩合并,得到一棵新的线段树。你要怎么做呢?
提供这样一种算法(tree(x, y, z)表示一个左儿子是x、右儿子是y、数据是z的新结点):
tree *merge(int l, int r, tree *A, tree *B){ if(A == NULL) return B; if(B == NULL) return A; if(l == r) return new tree(NULL, NULL, A -> data + B -> data); int mid = (l + r) >> 1; return new tree(merge(l, mid, A -> ls, B -> ls), merge(mid + 1, r, A -> rs, B -> rs), A -> data + B -> data); }
(上面的代码瞎写的……发现自己不会LaTeX写伪代码,于是瞎写了个“不伪的代码”,没编译过,凑付看 ><)
这个算法的复杂度是多少呢?显然是A、B两棵树重合的节点的个数。
那么假如你手里有m个只有一个元素的“权值线段树”,权值范围是[1,n][1,n],想都合并起来,复杂度是多少呢?复杂度是O(mlogn)O(mlogn)咯。
这个合并线段树的技巧可以解决一些问题——例如这个:BZOJ 2212。
题意:
给出一棵完全二叉树,每个叶子节点有一个权值,你可以任意交换任意节点的左右儿子,然后DFS整棵树得到一个叶子节点组成的序列,问这个序列的逆序对最少是多少。
可以看出,一个子树之内调换左右儿子,对子树之外的节点没有影响。于是可以DFS整棵树,对于一个节点的左右儿子,如果交换后左右儿子各出一个组成的逆序对更少则交换,否则不交换。如何同时求出交换与不交换左右儿子情况下的逆序对数量?可以使用线段树合并。
用两个权值线段树分别表示左右儿子中所有的数的集合。在合并两棵线段树的同时,A -> right_son
与B -> left_son
可以构成不交换左右儿子时的一些逆序对,A -> left_son
与B -> right_son
可以构成交换左右儿子时的一些逆序对,其余的逆序对在线段树A
、B
的左右子树中,可以在递归合并的时候处理掉。
只有叶子节点才有权值。
#include <cstdio> #include <cmath> #include <cstring> #include <algorithm> #define space putchar(' ') #define enter putchar('\n') using namespace std; typedef long long ll; template <class T> void read(T &x){ char c; bool op = 0; while(c = getchar(), c < '0' || c > '9') if(c == '-') op = 1; x = c - '0'; while(c = getchar(), c >= '0' && c <= '9') x = x * 10 + c - '0'; if(op) x = -x; } template <class T> void write(T x){ if(x < 0) putchar('-'), x = -x; if(x >= 10) write(x / 10); putchar(x % 10 + '0'); } const int N = 10000005; int n, tmp, ls[N], rs[N], data[N], tot; ll ans, res1, res2; int newtree(int l, int r, int x){ data[++tot] = 1; if(l == r) return tot; int mid = (l + r) >> 1, node = tot; if(x <= mid) ls[node] = newtree(l, mid, x); else rs[node] = newtree(mid + 1, r, x); return node; } int merge(int l, int r, int u, int v){ if(!u || !v) return u + v; if(l == r) return data[++tot] = data[u] + data[v], tot; int mid = (l + r) >> 1, node = ++tot; res1 += (ll)data[rs[u]] * data[ls[v]], res2 += (ll)data[ls[u]] * data[rs[v]]; ls[node] = merge(l, mid, ls[u], ls[v]); rs[node] = merge(mid + 1, r, rs[u], rs[v]); data[node] = data[ls[node]] + data[rs[node]]; return node; } int dfs(){ read(tmp); if(tmp) return newtree(1, n, tmp); int node = merge(1, n, dfs(), dfs()); ans += min(res1, res2); res1 = res2 = 0; return node; } int main(){ read(n); dfs(); write(ans), enter; return 0; }
如何找指定区间内存在的最大的数,类似于权值线段树的做法。
如果右子树有,先找右子树,其次才能去找左子树。
int query(int now,int L,int R,int i,int j,int d) { if(L==R)//达到端点,直接返回 { if(minv[now]<=d) return L; else return -1; } if(i<=L && R<=j)//对于已经包含了的情况,直接判断,也是吧*logn变成+logn的关键 if(minv[now]>d) return -1; int m=(L+R)>>1; if(j<=m) return query(lc[now],L,m,i,j,d); if(i>=m+1) return query(rc[now],m+1,R,i,j,d);//正常的筛出小区间 int t1=query(lc[now],L,m,i,j,d);//先左子树 if(t1!=-1) return t1; return query(rc[now],m+1,R,i,j,d);//左子树不行才弄右子树 }
https://nanti.jisuanke.com/t/41296
大致题意:定义连续区间,即满足区间最大值与最小值之差加一恰好等于区间数字个数的区间,是连续区间。现在给你一个数列,问你这个数列的所有子区间中有多少个连续区间。
对连续区间进行量化,有区间的max-min+1==cnt,其中cnt表示区间中数字的种类数。注意到,对于任意一个区间,恒有max-min+1>=cnt,我们对式子变形,有max-min-cnt>=-1。因此,如果我们维护每一个区间的max-min-cnt的最小值,当这个最小值为-1的时候,说明存在有连续区间,这个最小值出现的次数就是连续区间的个数。那么问题就是如何维护所有区间的最小值。
注意到,维护max-min-cnt,这里面有一个cnt。根据以前的经验,遇到区间数字种类的问题,一般都是枚举右端点R,然后线段树维护1..L,表示以每一个位置为左端点的区间的数值。同样这题我们也可以利用这种方法,枚举右端点,然后更新,看看有多少个最小值为-1的区间,累积求和就是最后的解。
那么,就是要看怎么更新了。对于最大值,我们可以维护一个单调栈,如果当前新加入的数字比栈顶元素大,那么从栈顶元素出现的位置开始到新加入的数字之前进行修改,增加上二者的差值,然后退栈,继续与前一个比较更新。最小值也是类似,不一一赘述了。然后就是cnt,这个就是常规操作,对于区间[last[x],i]进行修改即可。最后线段树维护区间最小值以及区间最小值的数目,这个也很容易实现。具体见代码:
#include<bits/stdc++.h> using namespace std; const int N = 1e5 + 5; typedef long long ll; typedef pair<ll, ll> pii; int t, n; ll a[N]; ll sm[N], bi[N]; map<int, int> mp; ll mi[N << 2], lazy[N << 2]; ll cnt[N << 2]; void pushup(int rt) { if(mi[rt << 1] == mi[rt << 1 | 1]) { mi[rt] = mi[rt << 1]; cnt[rt] = cnt[rt << 1] + cnt[rt << 1 | 1]; } else if(mi[rt << 1] < mi[rt << 1 | 1]) { mi[rt] = mi[rt << 1]; cnt[rt] = cnt[rt << 1]; } else { mi[rt] = mi[rt << 1 | 1]; cnt[rt] = cnt[rt << 1 | 1]; } } void build(int rt, int l, int r) { lazy[rt] = 0; if(l == r) { mi[rt] = 0; cnt[rt] = 1; return; } int m = (l + r) >> 1; build(rt << 1, l, m); build(rt << 1 | 1, m + 1, r); pushup(rt); } void pushdown(int rt) { if(lazy[rt]) { lazy[rt << 1] += lazy[rt]; lazy[rt << 1 | 1] += lazy[rt]; mi[rt << 1] += lazy[rt]; mi[rt << 1 | 1] += lazy[rt]; lazy[rt] = 0; } } void update(int rt, int l, int r, int le, int re, ll val) { if(le <= l && re >= r) { mi[rt] += val; lazy[rt] += val; return; } pushdown(rt); int m = (l + r) >> 1; if(re <= m) update(rt << 1, l, m, le, re, val); else if(le > m) update(rt << 1 | 1, m + 1, r, le, re, val); else { update(rt << 1, l, m, le, m, val); update(rt << 1 | 1, m + 1, r, m + 1, re, val); } pushup(rt); } int main() { scanf("%d", &t); int ca = 0; while(t--) { mp.clear(); scanf("%d", &n); for(int i = 1; i <= n; i++) scanf("%lld", &a[i]); build(1, 1, n); ll ans = 0; int t1 = 0, t2 = 0; for(int i = 1; i <= n; i++) { while(t1 && a[bi[t1] ] < a[i]) { update(1, 1, n, bi[t1 - 1] + 1, bi[t1], a[i] - a[bi[t1] ]); t1--; } bi[++t1] = i; while(t2 && a[sm[t2] ] > a[i]) { update(1, 1, n, sm[t2 - 1] + 1, sm[t2], a[sm[t2] ] - a[i]); t2--; } sm[++t2] = i; update(1, 1, n, mp[a[i] ] + 1, i, -1); mp[a[i] ] = i; if(mi[1] == -1) ans += cnt[1]; } printf("Case #%d: %lld\n", ++ca, ans); } return 0; }