【HNOI2016】树
【HNOI2016】树
每一个复制过来的子树(我们称为一个树团)有用的只有需要被访问的节点,包括根,根的父亲,要询问的点。我们只需要求出这些点到其所在树团根的距离以及倍增数组就好了。
需要讨论一些不同的情况。
然而我头铁,写了虚树,时间/空间常数大到自闭(不敢乱写虚树了)。
要注意的细节是,我维护了两个数组:\(dep_v\)表示\(v\)到根的节点数;\(dis_v\)表示\(v\)到根的路径数。这两个数组不能混用。
代码:
#include<bits/stdc++.h>
#define ll long long
#define N 100005
using namespace std;
inline ll Get() {ll x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}
int n,m,Q;
struct road {int to,next;}s[N<<1];
int h[N],cnt;
void add(int i,int j) {s[++cnt]=(road) {j,h[i]};h[i]=cnt;}
int lx,rx;
int rt[N],sum[N*18];
int ls[N*18],rs[N*18];
int tot;
void Insert(int &v,int old,int lx,int rx,int p) {
v=++tot;
sum[v]=sum[old]+1;
ls[v]=ls[old];
rs[v]=rs[old];
if(lx==rx) return ;
int mid=lx+rx>>1;
if(p<=mid) Insert(ls[v],ls[old],lx,mid,p);
else Insert(rs[v],rs[old],mid+1,rx,p);
}
int query_sum(int a,int b,int lx,int rx,int lim) {
if(lx>lim) return 0;
if(rx<=lim) return sum[b]-sum[a];
int mid=lx+rx>>1;
return query_sum(ls[a],ls[b],lx,mid,lim)+query_sum(rs[a],rs[b],mid+1,rx,lim);
}
int Find_kth(int a,int b,int k,int lx,int rx) {
if(lx==rx) return lx;
int mid=lx+rx>>1;
if(k<=sum[ls[b]]-sum[ls[a]]) return Find_kth(ls[a],ls[b],k,lx,mid);
else return Find_kth(rs[a],rs[b],k-sum[ls[b]]+sum[ls[a]],mid+1,rx);
}
int dep[int(N*6.7)],size[N];
ll dis[int(N*6.7)];
int dfn[N],edf[N],dfn_id;
int lst[N];
struct Integer {
typedef unsigned char byte;
byte a,b,c;
operator int(){return int(a)<<16|int(b)<<8|int(c);}
Integer(int val=0){
a=val>>16;
b=(val>>8)&0xff;
c=val&0xff;
}
Integer operator + (Integer b){return Integer(int(*this)+int(b));}
Integer operator - (Integer b){return Integer(int(*this)-int(b));}
Integer operator * (Integer b){return Integer(int(*this)*int(b));}
Integer operator / (Integer b){return Integer(int(*this)/int(b));}
Integer operator = (Integer b){
this->a=b.a;
this->b=b.b;
this->c=b.c;
return *this;
}
};
Integer fa[int(N*6.7)][19];
void dfs(int v) {
dfn[v]=++dfn_id;
lst[dfn_id]=v;
for(int i=1;i<19;i++) fa[v][i]=fa[fa[v][i-1]][i-1];
size[v]=1;
for(int i=h[v];i;i=s[i].next) {
int to=s[i].to;
if(to==fa[v][0]) continue ;
fa[to][0]=v;
dep[to]=dep[v]+1;
dis[to]=dis[v]+1;
dfs(to);
size[v]+=size[to];
}
edf[v]=dfn_id;
}
int lca(int a,int b) {
if(dep[a]<dep[b]) swap(a,b);
for(int i=18;i>=0;i--)
if(fa[a][i]&&dep[fa[a][i]]>=dep[b])
a=fa[a][i];
if(a==b) return a;
for(int i=18;i>=0;i--)
if(fa[a][i]!=fa[b][i])
a=fa[a][i],b=fa[b][i];
return fa[a][0];
}
struct plant {ll a,b;}p[N];
struct query {ll a,b;}q[N];
ll key[N<<2];
int id_tot;
bool cmp(int a,int b) {return dfn[a]<dfn[b];}
map<ll,int>id;
vector<ll>tem;
ll pos[N];
int FLAG=1;
void build(ll d,int rt,ll FA) {
static vector<ll>a;
static ll st[N],top;
a.resize(tem.size());
for(int i=0;i<tem.size();i++) {
int x=Find_kth(::rt[dfn[rt]-1],::rt[edf[rt]],tem[i]-d,lx,rx);
a[i]=x;
}
sort(a.begin(),a.end(),cmp);
for(int i=0,x=a.size()-1;i<x;i++) {
a.push_back(lca(a[i],a[i+1]));
}
sort(a.begin(),a.end(),cmp);
int cc=unique(a.begin(),a.end())-a.begin();
for(int i=0;i<cc;i++) {
pos[a[i]]=d+query_sum(::rt[dfn[rt]-1],::rt[edf[rt]],lx,rx,a[i]);
}
for(int i=0;i<cc;i++) {
if(id.find(pos[a[i]])==id.end()) id[pos[a[i]]]=++id_tot;
}
for(int i=0;i<cc;i++) pos[a[i]]=id[pos[a[i]]];
fa[pos[a[0]]][0]=FA;
dep[pos[a[0]]]=dep[FA]+1;
dis[pos[a[0]]]=dis[FA]+1;
st[top=1]=a[0];
for(int i=1;i<cc;i++) {
while(edf[st[top]]<dfn[a[i]]) top--;
dep[pos[a[i]]]=dep[pos[st[top]]]+1;
dis[pos[a[i]]]=dis[pos[st[top]]]+dis[a[i]]-dis[st[top]];
fa[pos[a[i]]][0]=pos[st[top]];
st[++top]=a[i];
}
for(int i=0;i<cc;i++) {
int now=pos[a[i]];
for(int j=1;j<19;j++) fa[now][j]=fa[fa[now][j-1]][j-1];
}
}
int main() {
n=Get(),m=Get(),Q=Get();
lx=1,rx=n;
int a,b;
for(int i=1;i<n;i++) {
a=Get(),b=Get();
add(a,b),add(b,a);
}
dfs(1);
for(int i=1;i<=n;i++) {
rt[i]=rt[i-1];
Insert(rt[i],rt[i],lx,rx,lst[i]);
}
for(int i=1;i<=n;i++) id[i]=i;
ll SUM=n;
for(int i=1;i<=m;i++) {
p[i].a=Get(),p[i].b=Get();
key[++key[0]]=SUM+query_sum(rt[dfn[p[i].a]-1],rt[edf[p[i].a]],lx,rx,p[i].a);
if(p[i].b>n) key[++key[0]]=(p[i].b);
SUM+=size[p[i].a];
}
for(int i=1;i<=Q;i++) {
q[i].a=Get(),q[i].b=Get();
if(q[i].a>n) key[++key[0]]=(q[i].a);
if(q[i].b>n) key[++key[0]]=(q[i].b);
}
sort(key+1,key+1+key[0]);
id_tot=n;
int cc=unique(key+1,key+1+key[0])-key-1;
int tag=1;
SUM=n;
for(int i=1;i<=m;i++) {
tem.clear();
ll pre=SUM;
SUM+=size[p[i].a];
while(tag<=cc&&key[tag]<=SUM) tem.push_back(key[tag]),tag++;
build(pre,p[i].a,id[p[i].b]);
}
for(int i=1;i<=Q;i++) {
int a=id[q[i].a],b=id[q[i].b];
int f=lca(a,b);
cout<<dis[a]+dis[b]-2*dis[f]<<"\n";
}
return 0;
}