[AT2268] [agc008_f] Black Radius

题目链接

AtCoder:https://agc008.contest.atcoder.jp/tasks/agc008_f

洛谷:https://www.luogu.org/problemnew/show/AT2268

Solution

首先假设所有点都是黑的。

\(f(i,d)\)表示\(i\)节点扩展\(k\)步的点集,那么答案就是本质不同的点集个数。

我们考虑一个很巧妙的计数方法:每种点集都在\(d\)最小时被算一次,那么二元组一定要满足这样的性质:

  • 首先我们硬点全集不选,答案最后加一。
  • 对于\((x,d)\),我们要求所有于\(x\)相邻的点\(y\)都不存在\(f(x,d)=f(y,d-1)\)

那么我们可以发现每个点都有一个选取上界,这个\(d\)满足以下性质:

  • \(d\in [0,dis_x-1]\),其中\(dis_x\)表示离\(x\)最远点的距离。
  • \(d\in [0,dis2_v+1]\),其中\(v\)\(x\)的儿子,\(dis2_v\)表示\(x\)不经过\(v\)\(dis\)最大值。

这个画个图就可以知道。

那么如果有一些点不是黑的,我们考虑给这些点定个下界,下界就是以\(x\)为根\(x\)的儿子的子树中含有黑点的子树的\(dis_1\)的最小值,这样就可以保证这种方案可以被一个黑点产生。

然后\(\rm tree\ dp\)实现就好了,复杂度\(O(n)\)

#include<bits/stdc++.h>
using namespace std;

void read(int &x) {
    x=0;int f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
    for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}

void print(int x) {
    if(x<0) putchar('-'),x=-x;
    if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}

#define lf double
#define ll long long 

#define pii pair<int,int >
#define vec vector<int >

#define pb push_back
#define mp make_pair
#define fr first
#define sc second

const int maxn = 5e5+10;
const int inf = 1e9;
const lf eps = 1e-8;

char s[maxn];
int sz[maxn],d1[maxn],d2[maxn],d3[maxn],d4[maxn],n,head[maxn],tot,f[maxn];
struct edge{int to,nxt;}e[maxn<<1];

void add(int u,int v) {e[++tot]=(edge){v,head[u]},head[u]=tot;}
void ins(int u,int v) {add(u,v),add(v,u);}

void dfs(int x,int fa) {
    sz[x]=s[x]-'0',f[x]=fa;d3[x]=1e9;
    for(int v,i=head[x];i;i=e[i].nxt)
        if((v=e[i].to)!=fa) {
            dfs(v,x),sz[x]+=sz[v];
            d1[x]=max(d1[x],d1[v]+1);
            if(sz[v]) d3[x]=min(d3[x],d1[e[i].to]+1);
        }
}

void dfs2(int x,int fa) {
    int fr=0,sc=0;if(fa) d4[x]=d2[x]-1;
    for(int v,i=head[x];i;i=e[i].nxt) {
        if((v=e[i].to)==fa) continue;
        if(d1[v]+1>=fr) sc=fr,fr=d1[v]+1;
        else if(d1[v]+1>sc) sc=d1[v]+1;
    }
    for(int v,i=head[x];i;i=e[i].nxt) {
        if((v=e[i].to)==fa) continue;
        if(d1[v]+1==fr) d2[v]=max(d2[x],sc)+1;
        else d2[v]=max(d2[x],fr)+1;
        dfs2(e[i].to,x);
    }
}

int main() {
    read(n);for(int i=1,x,y;i<n;i++) read(x),read(y),ins(x,y);
    scanf("%s",s+1);dfs(1,0),dfs2(1,0);
    ll ans=0;int mx,mn;
    for(int x=1;x<=n;x++) {
        mx=max(d1[x],d2[x])-1;
        if(s[x]=='0') mn=min(d3[x],sz[1]==sz[x]?(int)1e9:d2[x]);else mn=0;
        for(int i=head[x];i;i=e[i].nxt)
            if(e[i].to==f[x]) mx=min(mx,d1[x]+1);
            else mx=min(mx,d4[e[i].to]+1);
        if(mx>=mn) ans+=(ll)mx-mn+1;
    }printf("%lld\n",ans+1ll);
    return 0;
}
posted @ 2019-04-19 14:01  Hyscere  阅读(227)  评论(0编辑  收藏  举报