[USACO12FEB]附近的牛Nearby Cows
分享一种奇妙的解法;(也不清楚有没有人写过emm)
首先,对于树上的任意一个点x,它k布以内能够到达的点分为3种情况:
1:当前节点正下方的点,如它的儿子和它儿子的儿子;
我们将x恰好j步所能到达的此类节点的权值和记作f1[x][j];
2:当前节点正上方的点,如它的父亲和它父亲的父亲;
我们将x恰好j步所能到达的此类节点的权值和记作f2[x][j];
3:当前节点的旁系血亲,如我父亲的另一个儿子,我祖先的另一个后代等;
我们将x恰好j步所能到达的此类节点的权值和也记在f2[x][j]里;(理由等下再做讨论)
显然,第一种和第二种情况可以DFS和BFS直接求;
那么,现在问题就变成了如何求它的旁系血亲,即同x不在一条链上的对答案有贡献的点;
我们发现,如当前节点能通过它的正上方的节点,也就是它的直系祖先(我们先记作fa[x]),到达它的旁系血亲;fa[x]也必然能够到达这位旁系血亲;那我们就可以在进行对tip2的求解时顺路传递下来了;(这就是我们将tip2与tip3统一用f2记录的原因);
既然这样,那么我们就只要求x父亲的其它子树中满足条件的节点了,即x的兄弟节点。
我们用y表示另一位兄弟节点,因为x要先到x,y的父亲fa[x],再走到y,这里要消耗2步的距离,那状态转移方程就顺理成章的出来了:
f2[x][j]+=f2[y][j-2];
如此这般,在最后输出答案时,我们只需要计算
$sum_{i}=\sum_{j=0,j\le k}f1[i][j]+f2[i][j]$
然后sum还需要-=f1[i][0];因为我们给f1[i][0],f2[i][0]都赋值为c[i];在这里被计算了两次,需要去重;
把每一个i的sum依次输出就行了;
值得注意的一点是某些毒瘤数据,直接算会被卡爆,我们可以加入一些优化;我的做法是处理出一个maxdep;
如果$2maxdep-1<=k$此时每个点必然能到达所以的点因为你从一个深度最大点到根再到另一个深度最大点需要的距离(2maxdqp-1)都在k的范围以内,那么所以的点显然可以互相到达;
以下是AC代码
#include<bits/stdc++.h>
#define ll long long
#define N 100001
using namespace std;
inline void read(int &x){
int datta=0;char chchc=getchar();bool okoko=0;
while(chchc<'0'||chchc>'9'){if(chchc=='-')okoko=1;chchc=getchar();}
while(chchc>='0'&&chchc<='9'){datta=datta*10+chchc-'0';chchc=getchar();}
x=okoko?-datta:datta;
}
int n,k,a,b,cnt,head[N],nxt[N<<1],to[N<<1];
int root=1,f1[N][21],f2[N][21],fa[N],nn[N];
int mmx,ssum,dep[N];
void link(int a,int b){
to[++cnt]=b,nxt[cnt]=head[a],head[a]=cnt;nn[a]++;
}
void get_ans1(int x,int last){
dep[x]=dep[last]+1;
if(dep[x]>mmx)mmx=dep[x];
for(int i=head[x];i;i=nxt[i]){
if(to[i]!=last){
get_ans1(to[i],x);
for(int j=1;j<=k;j++)
f1[x][j]+=f1[to[i]][j-1];
}
}
}
void get_ans2(int x,int last){
int p[nn[x]];int tot=0;
for(int i=head[x];i;i=nxt[i]){
if(to[i]!=last){
p[++tot]=to[i];
get_ans2(to[i],x);
}
}
for(int i=1;i<=tot;i++){
for(int j=1;j<=tot;j++){
if(i!=j){
for(int l=2;l<=k;l++)
f2[p[i]][l]+=f1[p[j]][l-2];
}
}
}
}
void get_ans3(int x,int last){
for(int i=head[x];i;i=nxt[i]){
if(to[i]!=last){
for(int j=1;j<=k;j++)
f2[to[i]][j]+=f2[x][j-1];
get_ans3(to[i],x);
}
}
}
int main(){
read(n),read(k);
for(int i=1;i<n;i++){
read(a),read(b);
link(a,b),link(b,a);
}
for(int i=1;i<=n;i++){
read(f1[i][0]);
f2[i][0]=f1[i][0];
ssum+=f1[i][0];
}
get_ans1(root,0);
if((mmx<<1)-1<=k){
for(int i=1;i<=n;i++){
printf("%d\n",ssum);
}return 0;
}
get_ans2(root,0);
get_ans3(root,0);
for(int i=1;i<=n;i++){
int sum=0;
for(int j=0;j<=k;j++)
sum+=f1[i][j]+f2[i][j];
sum-=f1[i][0];
printf("%d\n",sum);
}
return 0;
}