[POI2015]Odwiedziny
[POI2015]Odwiedziny
题目大意:
一棵\(n(n\le5\times10^4)\)个点的树,\(n\)次询问从一个点到另一个点的路径上,每次跳\(k\)个点,所经过的点权和。
思路:
分块思想。
当\(k\ge\sqrt n\)时,显然每次询问不会跳超过\(\sqrt n\)次,可以借助树链剖分在\(\mathcal O(\sqrt n)\)的时间内暴力完成询问。
当\(k<\sqrt n\)时,预处理从一个点出发,每次跳\(k\)格,跳到根结点的权值和。可以\(\mathcal O(\log n)\)求LCA,\(\mathcal O(1)\)回答。
时间复杂度\(\mathcal O(n\sqrt n)\)。
源代码:
#include<cmath>
#include<cstdio>
#include<cctype>
#include<vector>
inline int getint() {
register char ch;
while(!isdigit(ch=getchar()));
register int x=ch^'0';
while(isdigit(ch=getchar())) x=(((x<<2)+x)<<1)+(ch^'0');
return x;
}
const int N=50001,B=223;
int n,block,a[N],b[N],c[N];
std::vector<int> e[N];
inline void add_edge(const int &u,const int &v) {
e[u].push_back(v);
e[v].push_back(u);
}
int anc[N][B],sum[N][B],dep[N],top[N],son[N],size[N],dfn[N],id[N];
void dfs(const int &x,const int &par) {
size[x]=1;
anc[x][1]=par;
sum[x][1]=sum[par][1]+a[x];
dep[x]=dep[par]+1;
for(register int i=2;i<block;i++) {
anc[x][i]=anc[anc[x][i-1]][1];
sum[x][i]=sum[anc[x][i]][i]+a[x];
}
for(unsigned i=0;i<e[x].size();i++) {
const int &y=e[x][i];
if(y==par) continue;
dfs(y,x);
size[x]+=size[y];
if(size[y]>size[son[x]]) {
son[x]=y;
}
}
}
void dfs(const int &x) {
dfn[x]=++dfn[0];
id[dfn[x]]=x;
top[x]=x==son[anc[x][1]]?top[anc[x][1]]:x;
if(son[x]) dfs(son[x]);
for(unsigned i=0;i<e[x].size();i++) {
const int &y=e[x][i];
if(y==anc[x][1]||y==son[x]) continue;
dfs(y);
}
}
inline int lca(int x,int y) {
while(top[x]!=top[y]) {
if(dep[top[x]]<dep[top[y]]) std::swap(x,y);
x=anc[top[x]][1];
}
if(dep[x]<dep[y]) std::swap(x,y);
return y;
}
inline int father(int x,int k) {
if(k>=dep[x]) return 0;
while(k>=dep[x]-dep[top[x]]+1) {
k-=dep[x]-dep[top[x]]+1;
x=anc[top[x]][1];
}
return id[dfn[x]-k];
}
inline int calc(int x,int y,const int &k) {
if(dep[x]<=dep[y]) return 0;
int ret=0;
if(k<block) {
while(y&&(dep[x]-dep[y])%k) y=anc[y][1];
ret=sum[x][k]-sum[y][k];
} else {
while(dep[x]>dep[y]) {
ret+=a[x];
x=father(x,k);
}
}
return ret;
}
inline int query(int x,int y,const int &k) {
const int z=lca(x,y),dis=dep[x]+dep[y]-dep[z]*2;
int ret=calc(x,z,k);
if(dis%k) {
ret+=a[y];
y=father(y,dis%k);
}
ret+=calc(y,anc[z][1],k);
return ret;
}
int main() {
block=sqrt(n=getint());
for(register int i=1;i<=n;i++) a[i]=getint();
for(register int i=1;i<n;i++) {
add_edge(getint(),getint());
}
dfs(1,0);
dfs(1);
for(register int i=1;i<=n;i++) b[i]=getint();
for(register int i=1;i<n;i++) {
printf("%d\n",query(b[i],b[i+1],getint()));
}
return 0;
}