[AGC008F] Black Radius(树形dp)
神题啊!!
Description
给你一棵有N个节点的树,节点编号为1到N,所有边的长度都为1
"全"对某些节点情有独钟,这些他喜欢的节点的信息会以一个长度为N的字符串s的形式给到你,具体一点就是对于1<=i<=N,si=1表示"全"喜欢节点i,为0表示"全"不喜欢节点i
一开始的时候,所有的节点都是白色的,"全"会进行以下操作恰好一次:
选择一个他喜欢的节点v和一个非负整数d,然后将所有与节点v距离不超过d的节点全部涂黑
问进行操作之后,有多少种不同的涂色情况?两种情况不同当且仅当两种情况存在一个节点i的颜色不同
Input
第一行一个正整数N
接下来N−1行每行两个正整数xi,yi表示xi到yi有一条边
最后一行一个字符串s
Output
输出不同染色情况的数量
题解:
设\(f(i,d)\),为把距 \(i\) 点小于等于 \(d\) 染成黑色的集合。
不过这样染色的时候会有重复的方案,我们考虑下何时会重复。
当 $f(i,d_1) $ 和 \(f(j,d2)\),重复时,当且仅当 \(i\),\(j\) 之间存在一点 \(k\),使得 \(f(k,d1-dist(i,k))\),\(dist(i,j)\) 表示 \(i\),\(j\) 的树上距离。
这个很显然吧……
所以对于某个\(f(i,d)\),如果有\(f(k,d1-dist(i,k))=f(i,d)\),那么 \(f(i,d)\) 就会被重复算。
对于每一个\(i\),我们只用考虑更他相邻的点和他的重复情况就可以了。(也很显然吧)
那我们再考虑一下,何时 \(f(i,d)=f(k,d-1)\)。
还是很显然,如果以 \(k\) 为根,\(i\) 的子树全部被染成黑色的话,他们就相等。
所以可以得知,\(d\) 的上界为 \(i\) 到他子树中的最远点的距离。
那我们用树形dp,\(O(n)\) 算出,然后暴力枚举每个 \(i\),算 \(d\) 的上界。
那下界呢? 如果 \(i\) 点是特殊点,无疑是0,那如果不是呢?
如果不是,那么若存在一个特殊节点j满足方案\((i,d)\)中 \(j\) 所在子树内所有节点均被染成黑色,\((i,d)\) 就是一个合法的染色方案。故我们只需要求出从 \(i\) 出发的至少经过 1 个特殊节点到达 \(j\) 子树中的最远节点的距离的最小值,就是可行的 \(d\) 的最小值。
知道每个 \(i\),\(d\) 的上下界,那就算吧!
CODE:
#include<iostream>
#include<cstdio>
using namespace std;
int d1[200005],d2[200005];
int d3[200005],d4[200005];
//d1:子树中的最远距离
//d2:非子树中的最远距离
//d3:这个节点到子树中至少经过1个特殊节点到达子树中的最远节点最小距离
//d4:从他父亲出发,不经过这棵子树的最远路径
int n,x,y,siz[200005],col[200005],fa[200005];
long long ans=0;
int tot=0,h[200005];
struct Edge{
int x,next;
}e[400005];
char ch[200005];
inline void add_edge(int x,int y){
e[++tot].x=y;
e[tot].next=h[x],h[x]=tot;
}
void dfs1(int x,int father){
siz[x]=col[x],fa[x]=father;
d3[x]=col[x]?0:1e9;
for(int i=h[x];i;i=e[i].next){
if(e[i].x==father)continue;
dfs1(e[i].x,x);
siz[x]+=siz[e[i].x];
d1[x]=max(d1[x],d1[e[i].x]+1);
if(siz[e[i].x])
d3[x]=min(d3[x],d1[e[i].x]+1);
}
}
void dfs2(int x,int father){
if(father)d4[x]=d2[x]-1;
int maxn=0,sec=0;
for(int i=h[x];i;i=e[i].next){
if(e[i].x==father)continue;
if(d1[e[i].x]+1>maxn)
sec=maxn,maxn=d1[e[i].x]+1;
else sec=max(sec,d1[e[i].x]+1);
}
for(int i=h[x];i;i=e[i].next){
if(e[i].x==father)continue;
if(d1[e[i].x]+1==maxn)
d2[e[i].x]=max(d2[x],sec)+1;
else d2[e[i].x]=max(d2[x],maxn)+1;
dfs2(e[i].x,x);
}
}
int main(){
scanf("%d",&n);
for(int i=1;i<n;i++){
scanf("%d%d",&x,&y);
add_edge(x,y);
add_edge(y,x);
}
scanf("%s",ch+1);
for(int i=1;ch[i];i++)
col[i]=ch[i]-'0';
dfs1(1,0),dfs2(1,0);
for(int i=1;i<=n;i++){
int minv,maxv;
minv=min(d3[i],siz[1]==siz[i]?(int)1e9:d2[i]);
maxv=max(d1[i],d2[i])-1;
for(int j=h[i];j;j=e[j].next){
if(e[j].x==fa[i])maxv=min(maxv,d1[i]+1);
else maxv=min(maxv,d4[e[j].x]+1);
}
if(maxv>=minv)ans+=1LL*maxv-minv+1;
}
printf("%lld",ans+1);
}