BZOJ 1146 二分+链剖+线段树+treap
思路:
恶心的数据结构题……
首先 我们 链剖 把树 变成序列 再 套一个 区间 第K大就好了……
复杂度(n*log^4n)
//By SiriusRen
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define N 88888
#define inf 100000000
int n,q,first[N],next[N*2],v[N*2],t[N],tot,op,xx,yy;
int fa[N],son[N],deep[N],top[N],siz[N],cnt,ch[N];
int root[N*40],size;
struct Treap{int ch[2],sz,cnt,v,rnd;}tr[N*40];
void Upd(int k){tr[k].sz=tr[tr[k].ch[0]].sz+tr[tr[k].ch[1]].sz+tr[k].cnt;}
void Rot(int &k,bool f){int t=tr[k].ch[f];tr[k].ch[f]=tr[t].ch[!f],tr[t].ch[!f]=k,Upd(k),Upd(t),k=t;}
void Insert(int &k,int num){
if(!k){k=++size;tr[k].sz=tr[k].cnt=1,tr[k].rnd=rand();tr[k].v=num;return;}
tr[k].sz++;
if(tr[k].v==num){tr[k].cnt++;return;}
bool f=num>tr[k].v;
Insert(tr[k].ch[f],num);
if(tr[tr[k].ch[f]].rnd<tr[k].rnd)Rot(k,f);
}
void Del(int &k,int num){
if(tr[k].v==num){
if(tr[k].cnt>1)tr[k].cnt--,tr[k].sz--;
else if(tr[k].ch[0]*tr[k].ch[1]==0)k=max(tr[k].ch[0],tr[k].ch[1]);
else Rot(k,tr[tr[k].ch[0]].rnd>tr[tr[k].ch[1]].rnd),Del(k,num);
}
else tr[k].sz--,Del(tr[k].ch[num>tr[k].v],num);
}
int get_rank(int k,int num){
if(!k)return 0;
if(tr[k].v==num)return tr[tr[k].ch[1]].sz;
else if(tr[k].v<num)return get_rank(tr[k].ch[1],num);
else return get_rank(tr[k].ch[0],num)+tr[tr[k].ch[1]].sz+tr[k].cnt;
}
void insert(int l,int r,int pos,int num,int wei){
Insert(root[pos],wei);
if(l==r)return;
int mid=(l+r)>>1,lson=pos<<1,rson=pos<<1|1;
if(mid<num)insert(mid+1,r,rson,num,wei);
else insert(l,mid,lson,num,wei);
}
void change(int l,int r,int pos,int num,int wei){
Del(root[pos],t[xx]),Insert(root[pos],wei);
if(l==r)return;
int mid=(l+r)>>1,lson=pos<<1,rson=pos<<1|1;
if(mid<num)change(mid+1,r,rson,num,wei);
else change(l,mid,lson,num,wei);
}
int query(int l,int r,int pos,int L,int R,int num){
if(l>=L&&r<=R)return get_rank(root[pos],num);
int mid=(l+r)>>1,lson=pos<<1,rson=pos<<1|1;
if(mid<L)return query(mid+1,r,rson,L,R,num);
else if(mid>=R)return query(l,mid,lson,L,R,num);
else return query(l,mid,lson,L,R,num)+query(mid+1,r,rson,L,R,num);
}
void Add(int x,int y){v[tot]=y,next[tot]=first[x],first[x]=tot++;}
void add(int x,int y){Add(x,y),Add(y,x);}
void dfs(int x){
siz[x]=1;
for(int i=first[x];~i;i=next[i])
if(v[i]!=fa[x]){
fa[v[i]]=x,deep[v[i]]=deep[x]+1;
dfs(v[i]),siz[x]+=siz[v[i]];
if(siz[v[i]]>siz[son[x]])son[x]=v[i];
}
}
void dfs2(int x,int tp){
top[x]=tp,ch[x]=++cnt;
insert(1,n,1,cnt,t[x]);
if(son[x])dfs2(son[x],tp);
for(int i=first[x];~i;i=next[i])
if(v[i]!=fa[x]&&v[i]!=son[x])
dfs2(v[i],v[i]);
}
int find(int x,int y,int num){
int fx=top[x],fy=top[y],tmp=0;
while(fx!=fy){
if(deep[fx]<deep[fy])swap(fx,fy),swap(x,y);
tmp+=query(1,n,1,ch[fx],ch[x],num);
x=fa[fx],fx=top[x];
}
if(deep[x]>deep[y])swap(x,y);
return tmp+query(1,n,1,ch[x],ch[y],num);
}
void b_srch(){
int l=0,r=inf,ans;
while(l<=r){
int mid=(l+r)>>1;
if(find(xx,yy,mid)>=op)l=mid+1;
else ans=mid,r=mid-1;
}
if(!ans)puts("invalid request!");
else printf("%d\n",ans);
}
int main(){
memset(first,-1,sizeof(first));
scanf("%d%d",&n,&q);
for(int i=1;i<=n;i++)scanf("%d",&t[i]);
for(int i=1;i<n;i++)scanf("%d%d",&xx,&yy),add(xx,yy);
dfs(1),dfs2(1,1);
for(int i=1;i<=q;i++){
scanf("%d%d%d",&op,&xx,&yy);
if(op)b_srch();
else change(1,n,1,ch[xx],yy),t[xx]=yy;
}
}