P7518-[省选联考2021A/B卷]宝石【主席树,二分】
正题
题目链接:https://www.luogu.com.cn/problem/P7518
题目大意
给出\(n\)个点的一棵树,每个点上有不大于\(m\)的数字。
然后给出一个长度为\(c\)的各个位数不同的序列,每次询问一条路径上找到一个最大的\(k\)使得该序列的存在\(1\sim k\)的子序列。
\(1\leq n,q\leq 2\times 10^5,1\leq c\leq m\leq 5\times 10^4,1\leq w_i\leq m\)
解题思路
传统的思想,路径分为向上和向下的两部分。然后因为序列没有重复元素,所以相当于对于每一种存在于序列的宝石都有唯一的下一种宝石。
先考虑向上的,发现我们必须从一开始,所以其实我们可以考虑离线记录一个\(last\)数组,其中\(last_i\)表示到根节点的路径中上一个\(i\)类型的是什么。
然后每个节点维护一棵线段树,对于节点\(x\)若是第\(i\)种宝石,那么第\(j\)个位置就储存它往上走到按顺序第\(i\sim j\)颗宝石的最大深度,这个可以每次从\(last_{i+1}\)处继承一棵树然后修改一个位置就好了。
然后询问的时候就直接从\(last_1\)处的树上二分出我们需要深度就可以确定我们往上走的路径能走到哪里了。
考虑向下的路径,我们把它拆成一条反向向上的路径,但是起点不是固定的,所以我们可以直接二分答案,然后在\(last_{mid}\)处向上走到\(LCA\)时,查看是否上和下的路径的序列有重复部分就好了。
时间复杂度\(O(n\log^2 n)\)
code
考场代码比较凌乱
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<cctype>
using namespace std;
const int N=2e5+10,T=18;
struct edge{
int to,next;
}a[N<<1];
int n,m,c,tot,w[N],p[N],ls[N],ans[N],lca[N];
int f[N][T+1],dep[N],las[N],rev[N],rt[N],up[N];
vector<int> vs[N],vt[N];
int read(){
int x=0,f=1;char c=getchar();
while(!isdigit(c)){if(c=='-')f=-f;c=getchar();}
while(isdigit(c)){x=(x<<1)+(x<<3)+c-'0';c=getchar();}
return x*f;
}
struct SegTree{
int cnt,w[N*20],ls[N*20],rs[N*20];
int Change(int x,int L,int R,int pos,int val){
int now=++cnt;w[now]=max(w[x],val);
if(L==R){ls[now]=rs[now]=0;return now;}
int mid=(L+R)>>1;
if(pos<=mid)ls[now]=Change(ls[x],L,mid,pos,val),rs[now]=rs[x];
else rs[now]=Change(rs[x],mid+1,R,pos,val),ls[now]=ls[x];
return now;
}
int Ask(int x,int L,int R,int k){
if(!x)return 0;
if(L==R)return L;int mid=(L+R)>>1;
if(w[rs[x]]<k)return Ask(ls[x],L,mid,k);
return Ask(rs[x],mid+1,R,k);
}
int Bsk(int x,int L,int R,int k){
if(!x)return c+1;
if(L==R)return L;int mid=(L+R)>>1;
if(w[ls[x]]<k)return Bsk(rs[x],mid+1,R,k);
return Bsk(ls[x],L,mid,k);
}
}Tr;
void addl(int x,int y){
a[++tot].to=y;
a[tot].next=ls[x];
ls[x]=tot;return;
}
void dfs(int x,int fa){
f[x][0]=fa;dep[x]=dep[fa]+1;
for(int i=ls[x];i;i=a[i].next){
int y=a[i].to;
if(y==fa)continue;
dfs(y,x);
}
return;
}
int LCA(int x,int y){
if(dep[x]>dep[y])swap(x,y);
for(int i=T;i>=0;i--)
if(dep[f[y][i]]>=dep[x])y=f[y][i];
if(x==y)return x;
for(int i=T;i>=0;i--)
if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i];
return f[x][0];
}
void calc(int x,int fa){
int P=p[w[x]];
if(P){
rev[x]=las[P];las[P]=x;
rt[x]=Tr.Change(rt[las[P+1]],0,c,P,dep[x]);
}
for(int i=0;i<vs[x].size();i++){
int id=vs[x][i];
up[id]=Tr.Ask(rt[las[1]],0,c,dep[lca[id]]);
}
for(int i=ls[x];i;i=a[i].next){
int y=a[i].to;
if(y==fa)continue;
calc(y,x);
}
if(P)las[P]=rev[x];
return;
}
void solve(int x,int fa){
int P=p[w[x]];
if(P){
rev[x]=las[P];las[P]=x;
rt[x]=Tr.Change(rt[las[P-1]],1,c+1,P,dep[x]);
}
for(int i=0;i<vt[x].size();i++){
int id=vt[x][i],l=up[id]+1,r=c;
if(lca[id]==x){ans[id]=up[id];continue;}
while(l<=r){
int mid=(l+r)>>1;
int tmp=Tr.Bsk(rt[las[mid]],1,c+1,dep[lca[id]]+1);
if(tmp<=up[id]+1)l=mid+1;
else r=mid-1;
}
ans[id]=r;
}
for(int i=ls[x];i;i=a[i].next){
int y=a[i].to;
if(y==fa)continue;
solve(y,x);
}
if(P)las[P]=rev[x];
return;
}
int main()
{
n=read();m=read();c=read();
for(int i=1;i<=c;i++){
int x=read();p[x]=i;
}
for(int i=1;i<=n;i++)w[i]=read();
for(int i=1;i<n;i++){
int x=read(),y=read();
addl(x,y);addl(y,x);
}
dfs(1,0);
for(int j=1;j<=T;j++)
for(int i=1;i<=n;i++)
f[i][j]=f[f[i][j-1]][j-1];
m=read();
for(int i=1;i<=m;i++){
int s=read(),t=read();
lca[i]=LCA(s,t);
vs[s].push_back(i);
vt[t].push_back(i);
}
calc(1,1);Tr.cnt=0;
solve(1,1);
for(int i=1;i<=m;i++)
printf("%d\n",ans[i]);
return 0;
}
/*
7 3 3
2 3 1
2 1 3 3 2 1 3
1 2
2 3
1 4
4 5
4 6
6 7
5
3 5
1 3
7 3
5 7
7 5
*/