bzoj 2588: Spoj 10628. Count on a tree【主席树+倍增】
算是板子,把值离散化,每个点到跟上做主席树,然后查询的时候主席树上用u+v-lca-fa[lca]的值二分
#include<iostream>
#include<cstdio>
#include<map>
#include<algorithm>
using namespace std;
const int N=100005;
int n,m,h[N],cnt,tot,la,a[N],ha[N],b[N],has,f[N][30],rt[N],ind,po[N],nu[N],de[N];
map<int,int>mp;
struct qwe
{
int ne,to;
}e[N<<1];
struct zhuxishu
{
int ls,rs,sum;
}t[2200005];
int read()
{
int r=0,f=1;
char p=getchar();
while(p>'9'||p<'0')
{
if(p=='-')
f=-1;
p=getchar();
}
while(p>='0'&&p<='9')
{
r=r*10+p-48;
p=getchar();
}
return r*f;
}
void add(int u,int v)
{
cnt++;
e[cnt].ne=h[u];
e[cnt].to=v;
h[u]=cnt;
}
void dfs(int u,int fat)
{
f[u][0]=fat;
nu[++ind]=u;
po[u]=ind;
de[u]=de[fat]+1;
for(int i=h[u];i;i=e[i].ne)
if(e[i].to!=fat)
dfs(e[i].to,u);
}
void update(int l,int r,int pr,int &ro,int w)
{
ro=++tot;
t[ro].sum=t[pr].sum+1;
if(l==r)
return;
t[ro].ls=t[pr].ls;
t[ro].rs=t[pr].rs;
int mid=(l+r)>>1;
if(w<=mid)
update(l,mid,t[pr].ls,t[ro].ls,w);
else
update(mid+1,r,t[pr].rs,t[ro].rs,w);
}
int lca(int x,int y)
{
if(de[x]<de[y])
swap(x,y);
for(int i=16;i>=0;i--)
if((1<<i)&(de[x]-de[y]))
x=f[x][i];
for(int i=16;i>=0;i--)
if(f[x][i]!=f[y][i])
x=f[x][i],y=f[y][i];
return x==y?x:f[x][0];
}
int ques(int x,int y,int k)
{
int a=x,b=y,c=lca(a,b),d=f[c][0];
a=rt[po[a]],b=rt[po[b]],c=rt[po[c]],d=rt[po[d]];
int l=1,r=has;
while(l<r)
{
int mid=(l+r)>>1;
int now=t[t[a].ls].sum+t[t[b].ls].sum-t[t[c].ls].sum-t[t[d].ls].sum;
if(now>=k)
{
r=mid;
a=t[a].ls,b=t[b].ls,c=t[c].ls,d=t[d].ls;
}
else
{
k-=now;
l=mid+1;
a=t[a].rs,b=t[b].rs,c=t[c].rs,d=t[d].rs;
}
}
return ha[l];
}
int main()
{
n=read(),m=read();
for(int i=1;i<=n;i++)
a[i]=read(),b[i]=a[i];
sort(b+1,b+1+n);
for(int i=1;i<=n;i++)
if(i==1||b[i]!=b[i-1])
mp[b[i]]=++has,ha[has]=b[i];
for(int i=1;i<=n;i++)
a[i]=mp[a[i]];
for(int i=1;i<n;i++)
{
int x=read(),y=read();
add(x,y);add(y,x);
}
dfs(1,0);
for(int j=1;j<=16;j++)
for(int i=1;i<=n;i++)
f[i][j]=f[f[i][j-1]][j-1];
for(int i=1;i<=n;i++)
update(1,has,rt[po[f[nu[i]][0]]],rt[i],a[nu[i]]);
for(int i=1;i<=m;i++)
{
int x=read(),y=read(),k=read();
x^=la;
la=ques(x,y,k);
printf("%d",la);
if(i!=m)
puts("");
}
return 0;
}