今天心血来潮,突然想到有主席树这个神奇的玩意儿。。。一直都只是听说也没敢看。(蒟蒻蛋蛋的忧伤。。。)
然后到网上翻大神的各种解释。。。看了半天。。。
一拍脑袋。。。哇其实主席树
真的难。。。【咳咳我只是来搞笑的】
看了很多种解释最后一头雾水啊。。。就是没法脑补出(嗯没错经常脑补数据结构长啥样)主席树的样子。。。。
最后终于找到了一个大大大大大大神犇的ppt,看到了主席树的真面目,才真的能弄懂主席树的结构。。。
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
下面进入正题【废话多】
主席树比较简单的理解就是,给你一个长为n的数列,你对于其中每个区间[1,i]都建立一棵线段树,用于保存这个区间中每个数出现的次数(这个显然线段树统计个数吧)。
然后这样你就得到了n棵线段树,如果按正常存储。。。MLE吧
不知道大佬用什么做线段树,反正我是数组模拟的。。。这样预先开好肯定会炸。。。
那么解决的办法呢,当然是有的【又是废话。。。我该改改这毛病了。。。】
用指针建树,就可以消灾解难了。为什么呢,仔细想想,如果说a[i]很小,那么就可以从tree[i-1]中得到很多重复的部分。譬如这n个数字是1~10,a[i]=3,那么显然tree[i]的右半棵子树和tree[i-1]一模一样,就可以直接使用。
如此算来,每加一个数,只要增加log(size)个节点 【注:size为n个数中不重复的个数】。那么只要O(n log n)就够了。
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
如果你脑中实在脑补不出(像我一样),那么下面就是图解:
【图中黑色的是原本存在的点,红色是新增的点,可见新增的点权值在(左数第二个黑点)和(最右边的黑点)之间。】
这样一来,应该可以有点概念了吧。。
如此看来,主席树还不算太难理解吧??
先看构造的数据结构。。。
1 struct node 2 { 3 node *son[2]; 4 int cnt; 5 node(){son[0]=son[1]=NULL;cnt=0;} 6 void update() 7 { 8 if (son[0]) cnt+=son[0]->cnt; 9 if (son[1]) cnt+=son[1]->cnt; 10 } 11 }*null=new node(),*root[200011]={NULL},q[1850011];
那么对于主席树的构建,也是好理解的了。只要连边或者加点就好了,复杂度毋庸置疑是O(log n)
1 void build(node *&y,node *&x,int l,int r,int tmp) 2 { 3 if (x==NULL) x=null; 4 y=&q[++qt]; 5 *y=node(); 6 int mid=l+r>>1; 7 if (l==r) 8 { 9 *y=*x; 10 y->cnt++; 11 return; 12 } 13 if (tmp<=a2[mid]) 14 { 15 build(y->son[0],x->son[0],l,mid,tmp); 16 y->son[1]=x->son[1]; 17 y->update(); 18 } 19 else 20 { 21 build(y->son[1],x->son[1],mid+1,r,tmp); 22 y->son[0]=x->son[0]; 23 y->update(); 24 } 25 }
还有关于查找第k大的方法,和线段树是类似的,由于主席树一个重要的性质就是可减,所以只要同时计算tree[l-1]和tree[r]的同区间节点并相减,然后和k比较,往下递归寻找,直到叶节点为止。复杂度O(log n)。
1 void find(node *&x1,node *&x2,int l,int r,int k) 2 { 3 if (x1==NULL) x1=null; 4 if (x2==NULL) x2=null; 5 if (l==r) 6 { 7 cout << a2[l] << "\n";return; 8 } 9 int mid=l+r>>1,hs=0; 10 if (x2->son[0]) hs+=x2->son[0]->cnt; 11 if (x1->son[0]) hs-=x1->son[0]->cnt; 12 if (hs>=k) find(x1->son[0],x2->son[0],l,mid,k); 13 else find(x1->son[1],x2->son[1],mid+1,r,k-hs); 14 }
大概就是这么个套路。。。
噢对了还有很重要的一点!!!!
对于size的获得与比较,都应在有序的基础上。所以可以先复制一个数组sort一下,然后unique一下,(就是个离散化)。
p党真不好意思我想你们可能不会看到这里了不过。。。就是个排序去重的意思。
那么对于裸题poj2104而言,代码如下(也许会和神犇有点像。。。)
1 #include <iostream> 2 #include <cstdio> 3 #include <algorithm> 4 using namespace std; 5 6 int n,m,a[200011],a2[200011]; 7 struct node 8 { 9 node *son[2]; 10 int cnt; 11 node(){son[0]=son[1]=NULL;cnt=0;} 12 void update() 13 { 14 if (son[0]) cnt+=son[0]->cnt; 15 if (son[1]) cnt+=son[1]->cnt; 16 } 17 }*null=new node(),*root[200011]={NULL},q[1850011]; 18 int qt; 19 20 int read() 21 { 22 int x=0,f=1;char c=getchar(); 23 while (!isdigit(c)) {if (c=='-')f=-1;c=getchar();} 24 while (isdigit(c)) {x=x*10+c-'0';c=getchar();} 25 return x*f; 26 } 27 28 void build(node *&y,node *&x,int l,int r,int tmp) 29 { 30 if (x==NULL) x=null; 31 y=&q[++qt]; 32 *y=node(); 33 int mid=l+r>>1; 34 if (l==r) 35 { 36 *y=*x; 37 y->cnt++; 38 return; 39 } 40 if (tmp<=a2[mid]) 41 { 42 build(y->son[0],x->son[0],l,mid,tmp); 43 y->son[1]=x->son[1]; 44 y->update(); 45 } 46 else 47 { 48 build(y->son[1],x->son[1],mid+1,r,tmp); 49 y->son[0]=x->son[0]; 50 y->update(); 51 } 52 } 53 54 void find(node *&x1,node *&x2,int l,int r,int k) 55 { 56 if (x1==NULL) x1=null; 57 if (x2==NULL) x2=null; 58 if (l==r) 59 { 60 cout << a2[l] << "\n";return; 61 } 62 int mid=l+r>>1,hs=0; 63 if (x2->son[0]) hs+=x2->son[0]->cnt; 64 if (x1->son[0]) hs-=x1->son[0]->cnt; 65 if (hs>=k) find(x1->son[0],x2->son[0],l,mid,k); 66 else find(x1->son[1],x2->son[1],mid+1,r,k-hs); 67 } 68 69 int main() 70 { 71 null->son[0]=null;null->son[1]=null; 72 n=read();m=read(); 73 for (int i=1;i<=n;i++) 74 { 75 a[i]=read(); 76 a2[i]=a[i]; 77 } 78 sort(a2+1,a2+1+n); 79 int sz=unique(a2+1,a2+1+n)-(a2+1); 80 for (int i=1;i<=n;i++) 81 build(root[i],root[i-1],1,sz,a[i]); 82 for (int i=1;i<=m;i++) 83 { 84 int ll,rr,kk; 85 ll=read(),rr=read(),kk=read(); 86 find(root[ll-1],root[rr],1,sz,kk); 87 } 88 return 0; 89 }