[AT2268] [agc008_f] Black Radius
题目链接
AtCoder:https://agc008.contest.atcoder.jp/tasks/agc008_f
洛谷:https://www.luogu.org/problemnew/show/AT2268
Solution
首先假设所有点都是黑的。
设\(f(i,d)\)表示\(i\)节点扩展\(k\)步的点集,那么答案就是本质不同的点集个数。
我们考虑一个很巧妙的计数方法:每种点集都在\(d\)最小时被算一次,那么二元组一定要满足这样的性质:
- 首先我们硬点全集不选,答案最后加一。
- 对于\((x,d)\),我们要求所有于\(x\)相邻的点\(y\)都不存在\(f(x,d)=f(y,d-1)\)。
那么我们可以发现每个点都有一个选取上界,这个\(d\)满足以下性质:
- \(d\in [0,dis_x-1]\),其中\(dis_x\)表示离\(x\)最远点的距离。
- \(d\in [0,dis2_v+1]\),其中\(v\)为\(x\)的儿子,\(dis2_v\)表示\(x\)不经过\(v\)的\(dis\)最大值。
这个画个图就可以知道。
那么如果有一些点不是黑的,我们考虑给这些点定个下界,下界就是以\(x\)为根\(x\)的儿子的子树中含有黑点的子树的\(dis_1\)的最小值,这样就可以保证这种方案可以被一个黑点产生。
然后\(\rm tree\ dp\)实现就好了,复杂度\(O(n)\)。
#include<bits/stdc++.h>
using namespace std;
void read(int &x) {
x=0;int f=1;char ch=getchar();
for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}
void print(int x) {
if(x<0) putchar('-'),x=-x;
if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}
#define lf double
#define ll long long
#define pii pair<int,int >
#define vec vector<int >
#define pb push_back
#define mp make_pair
#define fr first
#define sc second
const int maxn = 5e5+10;
const int inf = 1e9;
const lf eps = 1e-8;
char s[maxn];
int sz[maxn],d1[maxn],d2[maxn],d3[maxn],d4[maxn],n,head[maxn],tot,f[maxn];
struct edge{int to,nxt;}e[maxn<<1];
void add(int u,int v) {e[++tot]=(edge){v,head[u]},head[u]=tot;}
void ins(int u,int v) {add(u,v),add(v,u);}
void dfs(int x,int fa) {
sz[x]=s[x]-'0',f[x]=fa;d3[x]=1e9;
for(int v,i=head[x];i;i=e[i].nxt)
if((v=e[i].to)!=fa) {
dfs(v,x),sz[x]+=sz[v];
d1[x]=max(d1[x],d1[v]+1);
if(sz[v]) d3[x]=min(d3[x],d1[e[i].to]+1);
}
}
void dfs2(int x,int fa) {
int fr=0,sc=0;if(fa) d4[x]=d2[x]-1;
for(int v,i=head[x];i;i=e[i].nxt) {
if((v=e[i].to)==fa) continue;
if(d1[v]+1>=fr) sc=fr,fr=d1[v]+1;
else if(d1[v]+1>sc) sc=d1[v]+1;
}
for(int v,i=head[x];i;i=e[i].nxt) {
if((v=e[i].to)==fa) continue;
if(d1[v]+1==fr) d2[v]=max(d2[x],sc)+1;
else d2[v]=max(d2[x],fr)+1;
dfs2(e[i].to,x);
}
}
int main() {
read(n);for(int i=1,x,y;i<n;i++) read(x),read(y),ins(x,y);
scanf("%s",s+1);dfs(1,0),dfs2(1,0);
ll ans=0;int mx,mn;
for(int x=1;x<=n;x++) {
mx=max(d1[x],d2[x])-1;
if(s[x]=='0') mn=min(d3[x],sz[1]==sz[x]?(int)1e9:d2[x]);else mn=0;
for(int i=head[x];i;i=e[i].nxt)
if(e[i].to==f[x]) mx=min(mx,d1[x]+1);
else mx=min(mx,d4[e[i].to]+1);
if(mx>=mn) ans+=(ll)mx-mn+1;
}printf("%lld\n",ans+1ll);
return 0;
}