bzoj4381: [POI2015]Odwiedziny

这题调了我一下午……都是因为一些低级错误……

设阈值 $K \approx \sqrt{n}$(代码中直接取 $K=500$)。对于步长大于 $K$ 的询问,路径每一侧最多只需跳 $n/K$ 次,直接用倍增逐次往上跳、暴力累加即可。

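此外,事先预处理 $d[u][s]$($1 \le s \le K$):从 $u$ 出发每次往上跳 $s$ 步,直到跳出根为止,沿途经过的所有节点(含 $u$ 本身)的权值之和。它满足递推式 $d[u][s] = d[\mathrm{anc}_s(u)][s] + a_u$,其中 $\mathrm{anc}_s(u)$ 表示 $u$ 的 $s$ 级祖先(跳出根时视为权值为 $0$ 的虚拟节点)。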

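对于步长不超过 $K$ 的询问,先求出 $\mathrm{lca}(u,v)$,把路径拆成 $u \to \mathrm{lca}$ 与 $\mathrm{lca} \to v$ 两段;每一段的答案都可以用两个 $d$ 值相减(类似前缀和)直接查表得到,再单独处理 LCA 和终点 $v$ 即可。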

其余细节(最后一步不足一个步长时直接跳到终点、LCA 是否被经过且只能计入一次等)请参考代码。

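复杂度:预处理 $O(nK)$;步长 $\le K$ 的询问每次 $O(\log n)$;步长 $> K$ 的询问每次最多跳 $O(n/K)$ 次、每次倍增 $O(\log n)$。取 $K \approx \sqrt{n}$ 时总复杂度为 $O(n\sqrt{n}\log n)$。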

#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <cstdlib>
#include <algorithm>
// Array bound for per-node tables: must exceed the largest node id
// (ids are 1..n; index 0 is used as a sentinel "parent of the root").
#define N 50004
// Edge-pool size: AddEdge stores every undirected edge as two entries,
// and edge numbering starts at index 2 (m is initialised to 1 in main).
#define M 100005

using namespace std;
// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters in front of it (a leading '-' is NOT recognised).
// Returns 0 if end of input is reached before any digit is found.
// The skip loop must test for EOF explicitly: EOF (-1) is "not a digit",
// so without the check truncated input makes the loop spin forever.
inline int read(){
	int ret=0;
	int ch=getchar();  // int, not char, so EOF stays distinguishable
	while (ch!=EOF&&(ch<'0'||ch>'9')) ch=getchar();
	while ('0'<=ch&&ch<='9'){
		ret=ret*10+(ch-'0');
		ch=getchar();
	}
	return ret;
}

// Edge pool for adjacency lists kept as singly linked lists:
// g[u] is the index of the newest edge leaving u (0 = none), e[i].adj is
// that edge's other endpoint, and e[i].next the index of u's previous edge.
struct edge{
	int adj,next;
	edge(){}
	edge(int to,int nxt):adj(to),next(nxt){}
} e[M];
// n: number of nodes; m: index of the last used slot in e.
int n,g[N],m;

// Pushes the directed half-edge from->to onto the front of from's list.
static void AddArc(int from,int to){
	++m;
	e[m].adj=to;
	e[m].next=g[from];
	g[from]=m;
}

// Stores the undirected edge (u,v) as the two half-edges u->v and v->u.
void AddEdge(int u,int v){
	AddArc(u,v);
	AddArc(v,u);
}

// Stride threshold: queries with step <= K are answered from table d,
// larger strides are walked hop by hop with jump().
const int K=500;
// a[u]: value of node u read from the input (a[0] stays 0).
// f[u][k]: 2^k-th ancestor of u, 0 once past the root; k = 0..20 are used.
// d[u][s] (1 <= s <= K): sum of a over u, its s-th ancestor, its 2s-th
//   ancestor, ... until the chain runs past the root (d[0][*] is 0).
//   With 4-byte int, d alone occupies about 101 MB.
int a[N],f[N][23],d[N][K+5];
// fa[u]: parent of u (fa[root] = 0); deep[u]: depth, root = 1, deep[0] = 0.
int fa[N],deep[N];
void dfs(int u){
	deep[u]=deep[fa[u]]+1;
	int v=fa[u];
	for (int i=1;i<=K;++i,v=fa[v])
		d[u][i]=d[v][i]+a[u];
	for (int i=g[u];i;i=e[i].next){
		int v=e[i].adj;
		if (v==fa[u]) continue;
		fa[v]=u;
		dfs(v);
	}
}

// Roots the tree at node 1 and builds the ancestor tables.
// Node 0 acts as a sentinel: it is the root's parent, has depth 0, and all of
// its d/f entries are 0, so any climb past the root lands on node 0.
void precompute(){
	fa[0]=0;
	fa[1]=0;
	deep[0]=0;
	std::fill(d[0],d[0]+sizeof(d[0])/sizeof(d[0][0]),0);
	std::fill(f[0],f[0]+sizeof(f[0])/sizeof(f[0][0]),0);
	dfs(1);
	// Binary lifting: the 2^k-th ancestor is the 2^(k-1)-th ancestor of the
	// 2^(k-1)-th ancestor.
	for (int i=1;i<=n;++i) f[i][0]=fa[i];
	for (int k=1;k<=20;++k){
		for (int i=1;i<=n;++i){
			int mid=f[i][k-1];
			f[i][k]=f[mid][k-1];
		}
	}
}

// Returns the node reached by climbing `step` edges up from u via the
// binary-lifting table f. Only bits 0..20 of step are consulted; climbing
// past the root yields the sentinel node 0.
int jump(int u,int step){
	for (int k=0,bit=1;k<=20;++k,bit<<=1)
		if (step&bit) u=f[u][k];
	return u;
}
// Lowest common ancestor of u and v by binary lifting: lift the deeper node
// to the other's depth, then lift both while their 2^k-th ancestors differ;
// the answer is the parent of where they stop.
int qlca(int u,int v){
	if (deep[v]>deep[u]) std::swap(u,v);
	u=jump(u,deep[u]-deep[v]);
	if (u==v) return u;
	for (int k=20;k>=0;--k){
		if (f[u][k]==f[v][k]) continue;
		u=f[u][k];
		v=f[v][k];
	}
	return fa[u];
}

int s[N],t[N];
int main(){
	n=read();
	for (int i=1;i<=n;++i) a[i]=read();
	memset(g,0,sizeof(g));m=1;
	for (int i=1;i<n;++i) AddEdge(read(),read());
	precompute();
	for (int i=1;i<=n;++i) t[i-1]=s[i]=read();
	for (int i=1;i<n;++i){
		int step=read(),lca=qlca(s[i],t[i]),u=s[i],v=t[i],res=0;
		if ((deep[u]+deep[v]-2*deep[lca])%step>0){
			res+=a[v];
			v=jump(v,(deep[u]+deep[v]-2*deep[lca])%step);
		}
		if (deep[lca]%step==deep[u]%step) res+=a[lca];
		if (step<=K){
			int top1=jump(lca,(deep[lca]%step-deep[u]%step+step)%step),top2=jump(lca,(deep[lca]%step-deep[v]%step+step)%step);
			res+=d[u][step]-d[top1][step]+d[v][step]-d[top2][step];
		}
		else for (int j=0;j<2;++j)
			for (swap(u,v);deep[u]>deep[lca];u=jump(u,step))
				res+=a[u];
		printf("%d\n",res);
	}
	return 0;
}

  

posted @ 2016-02-27 17:33  wangyurzee  阅读(224)  评论(0)  编辑  收藏  举报