bzoj4381: [POI2015]Odwiedziny
这题调了我一下午……主要是卡在一些低级错误上……
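设阈值 $K$（理论上取 $K\approx\sqrt{n}$，代码中固定为 500）。对于步长大于 $K$ 的询问，直接暴力求解：沿路径用倍增每次跳 step 步，单次询问至多跳 $n/K$ 次。
对于步长不超过 $K$ 的情况，事先预处理 $d[u][\mathit{step}]$（$\mathit{step}\le K$）：从 $u$ 出发、每次往上跳 step 步、直到跳出根为止，所经过各点（含 $u$）的权值之和。
这样，步长不超过 $K$ 的询问只需查表并做几次加减：若路径长 $\mathit{dist}$ 不是 step 的倍数，先把终点 $v$ 计入答案，再将 $v$ 上移 $\mathit{dist}\bmod\mathit{step}$ 步，使剩余路径恰好由整步组成；之后 $u$、$v$ 两侧在 lca 之下的贡献分别为 $d[u][\mathit{step}]-d[\mathit{top}_u][\mathit{step}]$ 和 $d[v][\mathit{step}]-d[\mathit{top}_v][\mathit{step}]$，其中 $\mathit{top}_x$ 是 lca 及其祖先中深度与 $x$ 模 step 同余的最深者；lca 本身在其深度与 $u$ 模 step 同余时单独计入。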
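其余细节需自行推敲，例如：上移后的 $v$ 可能越过 lca，此时 $\mathit{top}_v$ 恰为上移后的 $v$，$v$ 一侧的贡献为 0。
复杂度：预处理 $O(nK)$；步长 $\le K$ 的询问 $O(\log n)$，步长 $> K$ 的询问 $O\!\left(\frac{n}{K}\log n\right)$。取 $K\approx\sqrt{n}$ 时总复杂度为 $O(n\sqrt{n}\log n)$。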
// BZOJ 4381 [POI2015] Odwiedziny.
//
// Input: n, vertex values a[1..n], n-1 tree edges, a visiting order c[1..n]
// and n-1 jump lengths. For each leg c[i] -> c[i+1] with jump length `step`,
// print the sum of a[] over every vertex where the traveller stops (both
// endpoints included; the last jump of a leg may be shorter than `step`).
//
// Square-root split on the jump length with threshold K:
//  * step <= K: O(log n) per leg using d[u][step], the sum of a[] over u,
//    its step-th ancestor, its 2*step-th ancestor, ... until past the root.
//  * step >  K: walk the path with binary-lifting jumps (at most n/K jumps).
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#define N 50004
#define M 100005
using namespace std;

// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters before it. Returns 0 if input ends before a digit.
// `ch` is an int so EOF stays distinct from every character value; storing
// getchar() in a char made the skip loop spin forever at end of input.
inline int read() {
    int ret = 0;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar();
    while (ch >= '0' && ch <= '9') {
        ret = ret * 10 + (ch - '0');
        ch = getchar();
    }
    return ret;
}

// Adjacency lists as linked lists of edges: g[u] is the index of u's first
// edge, e[i].next the following one, and index 0 terminates a list.
struct edge {
    int adj, next;
    edge() {}
    edge(int _adj, int _next) : adj(_adj), next(_next) {}
} e[M];
int n, g[N], m;

// Adds the undirected edge {u, v} as two directed edges.
void AddEdge(int u, int v) {
    e[++m] = edge(v, g[u]); g[u] = m;
    e[++m] = edge(u, g[v]); g[v] = m;
}

const int K = 500;   // jump-length threshold: table lookup vs. brute force
const int LOG = 20;  // highest binary-lifting level; 2^LOG exceeds N
int a[N], f[N][LOG + 3], d[N][K + 5];
int fa[N], deep[N];

// Roots the tree at vertex 1 and fills fa[], deep[] (root depth 1) and
// d[u][1..K] for every vertex. Vertex 0 is the sentinel above the root:
// fa[0] = 0, deep[0] = 0 and d[0][*] = 0.
//
// Iterative BFS rather than recursive DFS: a path-shaped tree is n levels
// deep, and that many nested calls can overflow a small stack.
void build_tree() {
    static int order[N];
    int head = 0, tail = 0;
    fa[1] = 0;
    deep[1] = 1;
    order[tail++] = 1;
    while (head < tail) {
        int u = order[head++];
        // BFS dequeues every ancestor of u before u, so d[anc][i] is final.
        int anc = fa[u];
        for (int i = 1; i <= K; ++i, anc = fa[anc]) d[u][i] = d[anc][i] + a[u];
        for (int i = g[u]; i; i = e[i].next) {
            int w = e[i].adj;
            if (w == fa[u]) continue;
            fa[w] = u;
            deep[w] = deep[u] + 1;
            order[tail++] = w;
        }
    }
}

// Builds the rooted tree, then the binary-lifting table f[v][k] = the
// 2^k-th ancestor of v (0 once past the root).
void precompute() {
    fa[1] = fa[0] = 0;
    deep[0] = 0;
    memset(d[0], 0, sizeof(d[0]));
    build_tree();
    for (int i = 1; i <= n; ++i) f[i][0] = fa[i];
    memset(f[0], 0, sizeof(f[0]));
    for (int k = 1; k <= LOG; ++k)
        for (int i = 1; i <= n; ++i) f[i][k] = f[f[i][k - 1]][k - 1];
}

// Returns the `step`-th ancestor of u, or 0 if that lies above the root.
// Expects 0 <= step < 2^(LOG+1).
int jump(int u, int step) {
    for (int k = 0; k <= LOG; ++k)
        if ((step >> k) & 1) u = f[u][k];
    return u;
}

// Lowest common ancestor of u and v via binary lifting.
int qlca(int u, int v) {
    if (deep[u] < deep[v]) swap(u, v);
    u = jump(u, deep[u] - deep[v]);
    if (u == v) return u;
    for (int k = LOG; k >= 0; --k)
        if (f[u][k] != f[v][k]) u = f[u][k], v = f[v][k];
    return fa[u];
}

int s[N], t[N];  // leg i goes from s[i] = c[i] to t[i] = c[i+1]

int main() {
    n = read();
    for (int i = 1; i <= n; ++i) a[i] = read();
    memset(g, 0, sizeof(g));
    m = 1;
    for (int i = 1; i < n; ++i) {
        // Separate declarators fix the read order; the order of function
        // argument evaluation is unspecified.
        int x = read(), y = read();
        AddEdge(x, y);
    }
    precompute();
    for (int i = 1; i <= n; ++i) t[i - 1] = s[i] = read();
    for (int i = 1; i < n; ++i) {
        int step = read();
        int u = s[i], v = t[i];
        int lca = qlca(u, v);
        int dist = deep[u] + deep[v] - 2 * deep[lca];
        int res = 0;
        // If the leg is not a whole number of jumps, the destination v is
        // reached by a short final jump: count v, then move v back to the
        // last regular stop. (That stop may lie above lca; then v's side
        // below contributes nothing.)
        if (dist % step > 0) {
            res += a[v];
            v = jump(v, dist % step);
        }
        // lca is a stop iff its distance from u is a multiple of step.
        if (deep[lca] % step == deep[u] % step) res += a[lca];
        if (step <= K) {
            // topX: deepest vertex among lca and its ancestors whose depth is
            // congruent to deep[X] mod step. d[X] - d[topX] sums X's stops
            // strictly below lca.
            int top1 = jump(lca, (deep[lca] % step - deep[u] % step + step) % step);
            int top2 = jump(lca, (deep[lca] % step - deep[v] % step + step) % step);
            res += d[u][step] - d[top1][step] + d[v][step] - d[top2][step];
        } else {
            // Large steps: at most n/K stops per side, walk them directly.
            for (int x = v; deep[x] > deep[lca]; x = jump(x, step)) res += a[x];
            for (int x = u; deep[x] > deep[lca]; x = jump(x, step)) res += a[x];
        }
        printf("%d\n", res);
    }
    return 0;
}