神奇的操作——线段树合并(例题: BZOJ2212)

什么是线段树合并?

首先你需要动态开点的线段树。(对每个节点维护左儿子、右儿子、存储的数据,然后要修改某儿子所在的区间中的数据的时候再创建该节点。)

考虑这样一个问题:

你现在有两棵权值线段树(大概是用来维护一个有很多数的可重集合那种线段树,若某节点对应区间是\([l, r]\),则它存储的数据是集合中\(\ge l\)\(\le r\)的数的个数),现在你想把它们俩合并,得到一棵新的线段树。你要怎么做呢?

提供这样一种算法(tree(x, y, z)表示一个左儿子是x、右儿子是y、数据是z的新结点):

tree *merge(int l, int r, tree *A, tree *B){
	if(A == NULL) return B;
	if(B == NULL) return A;
	if(l == r) return new tree(NULL, NULL, A -> data + B -> data);
	int mid = (l + r) >> 1;
	return new tree(merge(l, mid, A -> ls, B -> ls), merge(mid + 1, r, A -> rs, B -> rs), A -> data + B -> data);
}

(上面的代码瞎写的……发现自己不会LaTeX写伪代码,于是瞎写了个“不伪的代码”,没编译过,凑付看 ><)

这个算法的复杂度是多少呢?显然是A、B两棵树重合的节点的个数。

那么假如你手里有m个只有一个元素的“权值线段树”,权值范围是\([1, n]\),想都合并起来,复杂度是多少呢?复杂度是\(O(m\log n)\)咯。

这个合并线段树的技巧可以解决一些问题——例如这个:BZOJ 2212

题意:

给出一棵完全二叉树,每个叶子节点有一个权值,你可以任意交换任意节点的左右儿子,然后DFS整棵树得到一个叶子节点组成的序列,问这个序列的逆序对最少是多少。

可以看出,一个子树之内调换左右儿子,对子树之外的节点没有影响。于是可以DFS整棵树,对于一个节点的左右儿子,如果交换后左右儿子各出一个组成的逆序对更少则交换,否则不交换。如何同时求出交换与不交换左右儿子情况下的逆序对数量?可以使用线段树合并。

用两个权值线段树分别表示左右儿子中所有的数的集合。在合并两棵线段树的同时,A -> right_sonB -> left_son可以构成不交换左右儿子时的一些逆序对,A -> left_sonB -> right_son可以构成交换左右儿子时的一些逆序对,其余的逆序对在线段树AB的左右子树中,可以在递归合并的时候处理掉。

#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#define space putchar(' ')
#define enter putchar('\n')
using namespace std;
typedef long long ll;
template <class T>
void read(T &x){
    char c;
    bool op = 0;
    while(c = getchar(), c < '0' || c > '9')
        if(c == '-') op = 1;
    x = c - '0';
    while(c = getchar(), c >= '0' && c <= '9')
        x = x * 10 + c - '0';
    if(op) x = -x;
}
template <class T>
void write(T x){
    if(x < 0) putchar('-'), x = -x;
    if(x >= 10) write(x / 10);
    putchar(x % 10 + '0');
}

const int N = 10000005;
int n, tmp, ls[N], rs[N], data[N], tot;
ll ans, res1, res2;

int newtree(int l, int r, int x){
    data[++tot] = 1;
    if(l == r) return tot;
    int mid = (l + r) >> 1, node = tot;
    if(x <= mid) ls[node] = newtree(l, mid, x);
    else rs[node] = newtree(mid + 1, r, x);
    return node;
}
int merge(int l, int r, int u, int v){
    if(!u || !v) return u + v;
    if(l == r) return data[++tot] = data[u] + data[v], tot;
    int mid = (l + r) >> 1, node = ++tot;
    res1 += (ll)data[rs[u]] * data[ls[v]], res2 += (ll)data[ls[u]] * data[rs[v]];
    ls[node] = merge(l, mid, ls[u], ls[v]);
    rs[node] = merge(mid + 1, r, rs[u], rs[v]);
    data[node] = data[ls[node]] + data[rs[node]];
    return node;
}
int dfs(){
    read(tmp);
    if(tmp) return newtree(1, n, tmp);
    int node = merge(1, n, dfs(), dfs());
    ans += min(res1, res2);
    res1 = res2 = 0;
    return node;
}

int main(){
    read(n);
    dfs();
    write(ans), enter;
    return 0;
}
posted @ 2018-03-06 18:41  胡小兔  阅读(4873)  评论(2编辑  收藏  举报