BZOJ4381 : [POI2015]Odwiedziny

设$lim=\sqrt{n}$。

 

若$k<lim$,预处理出:

$F[i][x]$:$x$往上走$i$步到达的点。

$S[i][x]$:$x$不断往上走$i$步经过的点的和。

直接$O(1)$查询即可。

 

若$k\geq lim$:

查询时用树链剖分划分为$O(\log n)$条重链,在每条重链上暴力往上跳。

时间复杂度$O(\log n+\sqrt{n})$。


总时间复杂度$O(n\sqrt{n})$。

 

#include<cstdio>
const int N=50010,M=224;
int n,lim,i,j,x,y,a[N],b[N],g[N],v[N<<1],nxt[N<<1],ed;
int d[N],f[N],size[N],son[N],top[N],loc[N],seq[N],dfn;
int F[M][N],S[M][N];
inline void swap(int&a,int&b){int c=a;a=b;b=c;}
inline void add(int x,int y){v[++ed]=y;nxt[ed]=g[x];g[x]=ed;}
void dfs(int x){
  S[1][x]=S[1][F[1][x]=f[x]]+a[x];
  for(int i=2;i<lim;i++)S[i][x]=S[i][F[i][x]=f[F[i-1][x]]]+a[x];
  d[x]=d[f[x]]+1,size[x]=1;
  for(int i=g[x];i;i=nxt[i])if(v[i]!=f[x]){
    f[v[i]]=x,dfs(v[i]),size[x]+=size[v[i]];
    if(size[v[i]]>size[son[x]])son[x]=v[i];
  }
}
void dfs2(int x,int y){
  top[x]=y,seq[loc[x]=++dfn]=x;
  if(son[x])dfs2(son[x],y);
  for(int i=g[x];i;i=nxt[i])if(v[i]!=f[x]&&v[i]!=son[x])dfs2(v[i],v[i]);
}
inline int lca(int x,int y){
  while(top[x]!=top[y]){
    if(d[top[x]]<d[top[y]])swap(x,y);
    x=f[top[x]];
  }
  return d[x]<d[y]?x:y;
}
inline int father(int x,int y){
  while(loc[x]-loc[top[x]]<y){
    y-=loc[x]-loc[top[x]]+1;
    x=f[top[x]];
  }
  return seq[loc[x]-y];
}
inline int up(int x,int y,int k){
  y=father(x,d[x]-d[y]-(d[x]-d[y])%k);
  if(k<lim)return S[k][x]-S[k][y]+a[y];
  int t=0,i;
  while(top[x]!=top[y]){
    for(i=loc[x];i>=loc[top[x]];i-=k)t+=a[seq[i]];
    x=father(x,loc[x]-i);
  }
  for(i=loc[x];i>=loc[y];i-=k)t+=a[seq[i]];
  return t;
}
inline int query(int x,int y,int k){
  int z=lca(x,y),t=up(x,z,k);
  if((d[x]+d[y]-d[z]*2)%k)t+=a[y];
  if(z==y)return t;
  if((d[x]-d[z])%k==0)t-=a[z];
  int tmp=((d[z]-d[x])%k+k)%k;
  if(d[y]-d[z]-tmp<0)return t;
  return t+up(father(y,(d[y]-d[z]-tmp)%k),z,k);
}
int main(){
  scanf("%d",&n);
  while(lim*lim<n)lim++;
  for(i=1;i<=n;i++)scanf("%d",&a[i]);
  for(i=1;i<n;i++)scanf("%d%d",&x,&y),add(x,y),add(y,x);
  dfs(1),dfs2(1,1);
  for(i=1;i<=n;i++)scanf("%d",&b[i]);
  for(i=1;i<n;i++)scanf("%d",&x),printf("%d\n",query(b[i],b[i+1],x));
  return 0;
}

  

posted @ 2016-03-13 02:28  Claris  阅读(645)  评论(3编辑  收藏  举报