AGC008F Black Radius

Statement

Snuke 君有一棵 \(n\) 个节点的全白的树,其中有一些节点他喜欢,有一些节点他不喜欢。他会选择一个他喜欢的节点 \(x\),然后选择一个距离 \(d\),然后将所有与 \(x\) 距离不超过 \(d\) 的节点都染成黑色,问最后有多少种可能的染色后状态。

Solution

什么神仙题!!

问题要统计的可以转化为一个涂黑的连通块,把每一个合法连通块用另外一种方法表示出来。

将一个连通块表示为 \((x,d)\) 的二元组,其中 \(x\) 还可以是一条边,表示以 \(x\) 为直径中心,半径为 \(d\) 的连通块,这个是可以证明没啥问题的。

现在只要知道了中心和半径就可以确定一个连通块,可以考虑枚举中心 \(x\),然后求出对应的 \(d\) 有多少种取值就好了。

分类讨论一下 \(x\) 是边还是点。

假如 \(x\) 是连接 \((u,v)\) 的边,且 \(maxdep_u \leq maxdep_v\),那么我们显然需要填满子树 \(u\) 之后才会蔓延到 \(v\),否则 \(x\) 这条边不会是中心,于是如果 \(u\) 中有喜欢的点,那么最终会有 \(1\) 的贡献,否则没有贡献。

接下来考虑 \(x\) 是一个点的情况,还需要再分类一下这是个喜欢的点还是不喜欢的点。

在这种情况下直观感受 \(d\) 会是一段区间(。

如果 \(x\) 是一个喜欢的点,那么 \(d\) 的下界显然为 \(0\),可以只涂自己一个点。

上界考虑一下 \(x\)\(maxdep\) 最大的 \(y\),也就是 \(maxdep_y\) 发现不太行,因为假如这样的话直径中心应该会偏向 \(y\) 方向,所以减小一下,取个次大值发现就很对。

如果 \(x\) 不是一个喜欢的点,上界和这个无关,那么上界不变,只用思考下界怎么变。

下界应该是有喜欢的点的子树中最小的 \(maxdep\),因为为了涂到 \(x\),让 \(x\) 作为一个直径中点,所以最小的那个喜欢点 \(u\) 必须被涂到,但是这样 \(u\) 的下面部分也会被涂到,要让 \(x\) 作为直径中点而不偏移,于是将下面的也涂掉。

然后对于实现,每次都是让一个点或者一条边作为根,所以随便做一个换根 dp 就好了,然后实现有些细节,可以看代码,我写的好长啊。

// 德丽莎你好可爱德丽莎你好可爱德丽莎你好可爱德丽莎你好可爱德丽莎你好可爱
// 德丽莎的可爱在于德丽莎很可爱,德丽莎为什么很可爱呢,这是因为德丽莎很可爱!
// 没有力量的理想是戏言,没有理想的力量是空虚
#include <bits/stdc++.h>
#define LL long long
#define int long long
using namespace std;
char ibuf[1 << 15], *p1, *p2;
#define getchar() (p1 == p2 && (p2 = (p1 = ibuf) + fread(ibuf, 1, 1 << 15, stdin), p1==p2) ? EOF : *p1++)
inline int read() {
  char ch = getchar();  int x = 0, f = 1;
  while (ch < '0' || ch > '9')  {  if (ch == '-')  f = -1;  ch = getchar();  }
  while (ch >= '0' && ch <= '9')  x = (x << 1) + (x << 3) + (ch ^ 48), ch = getchar();
  return x * f;
}
void print(LL x) {
  if (x > 9)  print(x / 10);
  putchar(x % 10 + '0');
}
template<class T> bool chkmin(T &a, T b) { return a > b ? (a = b, true) : false; }
template<class T> bool chkmax(T &a, T b) { return a < b ? (a = b, true) : false; }
#define rep(i, l, r) for (int i = (l); i <= (r); i++)
#define repd(i, l, r) for (int i = (l); i >= (r); i--)
#define REP(i, l, r)  for (int i = (l); i < (r); i++)
const int N = 2e6;
int n, num, nex[N], first[N], v[N], a[N];
char s[N];
void add(int from,int to) {
  nex[++num] = first[from];  first[from] = num;  v[num] = to; 
}
int dep[N], mxdep1[N];
struct node {
  int mx, mx2, id;
  void pre() {  mx = mx2 = id = -1;  }
  void upd(int val,int now) {
    if (val > mx) {  mx2 = mx;  mx = val;  id = now;  }
    else if (val > mx2)  mx2 = val;
  }
} pt[N];
void dfs1(int x,int fa) {
  dep[x] = dep[fa] + 1;
  for (int i = first[x]; i; i = nex[i]) {
    int to = v[i];  if (to == fa)  continue;  dfs1(to, x);
  }
}
int f[N], g[N], mxdep2[N];
void dfs2(int x,int fa) {
  int son = 0;
  mxdep1[x] = 0;  g[x] = f[x] = a[x];
  pt[x].pre();  pt[x].upd(0, x);
  for (int i = first[x]; i; i = nex[i]) {
    int to = v[i];  if (to == fa)  continue;
    dfs2(to, x);  mxdep1[x] = max(mxdep1[x], mxdep1[to] + 1);  f[x] += f[to];
    pt[x].upd(mxdep1[to] + 1, to);
  }
}
void dfs3(int x,int fa) {
  int mx = 0, mx2 = 0, id = 0, id2 = 0;
  if (fa) {
    g[x] = g[fa] + f[fa] - f[x];
    if (x == pt[fa].id) {
      if (pt[fa].mx2 != -1) {
        mxdep2[x] = max(mxdep2[x], pt[fa].mx2 + 1);
        pt[x].upd(pt[fa].mx2 + 1, fa);
      }
    }  else {
      if (pt[fa].mx != -1) {
        mxdep2[x] = max(mxdep2[x], pt[fa].mx + 1);
        pt[x].upd(pt[fa].mx + 1, fa);
      }
    }
  }
  for (int i = first[x]; i; i = nex[i]) {
    int to = v[i];  if (to == fa)  continue;  dfs3(to, x);
  }
}
int ans, sum;
pair <int,int> e[N];
int solveedge(int i) {
  int x = e[i].first, y = e[i].second;
  if (dep[x] > dep[y])  swap(x, y);
  int now;
  if (pt[x].id == y)  now = pt[x].mx2;
  else  now = pt[x].mx;
  if (now == mxdep1[y]) {
    if (sum - f[y] || f[y])  return 1;
    return 0;
  }
  if (now < mxdep1[y]) {
    if (sum - f[y])  return 1;
    return 0;
  }
  if (now > mxdep1[y]) {
    if (f[y])  return 1;
    return 0;
  }
  return 1;
}
int solvever(int x) {
  int son = 0;
  for (int i = first[x]; i; i = nex[i])  son++;
  if (son <= 1)  return a[x];
  int u, d = n + 1, mx = -1, mx2 = -1;
  if (a[x])  d = 0;
  mx = mxdep2[x];
  if (mxdep2[x]  && sum - f[x])  d = min(d, mxdep2[x]);
  for (int i = first[x]; i; i = nex[i]) {
    int to = v[i];
    if (dep[to] < dep[x])  continue;
    if (f[to])  d = min(d, mxdep1[to] + 1);
    if (mxdep1[to] + 1 > mx) {
      mx2 = mx,  mx = mxdep1[to] + 1;
    }  else  if (mxdep1[to] + 1 > mx2) {
      mx2 = mxdep1[to] + 1;
    }
  }
  u = mx2;
  if (d == n + 1 || u < d)  return 0;
  return u - d + 1;
}
void solve() {
  cin >> n;
  rep (i, 2, n) {
    int x, y;  cin >> x >> y;
    add(x, y), add(y, x);
    e[i] = {x, y};
  }
  cin >> (s + 1);
  rep (i, 1, n)  a[i] = s[i] - '0', sum += a[i];
  dfs1(1, 0);  dfs2(1, 0);
  dfs3(1, 0);
  int ans = 0;
  // for (int i = 1; i <= n; i++) {
  //   cout << mxdep1[i] << " " << mxdep2[i] << " " << f[i] << " " << g[i] << " " << pt[i].mx << " " << pt[i].mx2 << " " << pt[i].id << "\n";
  // }
  // cout << "---\n";
  rep (i, 2, n)  ans += solveedge(i);
  // cout << ans << "\n" ;
  rep (i, 1, n)  ans += solvever(i);
  cout << ans << "\n";
}
signed main () {
#ifdef LOCAL_DEFINE
  freopen("1.in", "r", stdin);
  freopen("1.ans", "w", stdout);
#endif
  int T = 1;  while (T--)  solve();
#ifdef LOCAL_DEFINE
  cerr << "Time elapsed: " << 1.0 * clock() / CLOCKS_PER_SEC << " s.\n";
#endif
  return 0;
}
posted @ 2022-08-17 10:58  Pitiless0514  阅读(35)  评论(0编辑  收藏  举报