CF1988D (思维 + 树形DP)

题意

有一棵包含 n 个结点的树。编号为 i(1in) 的结点上有一个攻击力为 ai 的怪物。你要跟怪物对战若干回合,直至将它们全部杀死。
每一回合,所有存活着的怪物会先对你进行一次攻击,你损失的生命值是所有存活着的怪物的攻击力之和;然后,你选择若干结点 (u1u2u3um),满足:u1u2u3um 两两不在同一条边上,将 u1u2u3um 位置上的怪物杀死。
若选择最优的方案,在所有怪物被杀死之后,你最少损失多少生命值。

数据范围:1n3×1051ai1012

题解

首先,2 个回合一定可以把所有怪物打死,但是,显然这种做法不总是最优的。

考虑 树形DP,设 dp[U][i] 表示将以 U 为根的子树上的怪物全部杀死并且第 i 回合杀死结点 U 位置上的怪物 之后,最少损失多少生命值。

结论 :最优策略下的总回合数一定小于等于 logn+1

证明:对于任意一个结点 U,假设它的邻点为 V1V2V3Vm,击杀它们的回合数分别为:t1t2t3tm,那么在回合 MEX(t1t2tm) 击杀编号为 U 的结点上的怪物显然是最优的 (每拖一个回合,就多被攻击一次)。
如果存在某一个怪兽,我们在第 i 回合击杀它是最优的,那么我们假设树上结点的个数至少fi
有:f1=1i2,fi=1+j=1i1fj
fi=1+fi1+fi1

fi=1+4×fi2

fi=1+2i1

因为需要满足 fi=1+2i1n,所以,ilog(n1)+1,证毕。


DP 的转移方程:dp[U][i]=i×aU+Vson(U)minj[1,logn+1]&&jidp[V][j]

答案即:minj=1logn+1dp[1][j]

时间复杂度为 O(nlog2n)

点击查看代码
#include <bits/stdc++.h>

using i64 = long long;

constexpr i64 inf = 1E18;
constexpr int N = 3E5 + 5;

i64 a[N], dp[N][30], ans;
std::vector <int> adj[N];
int n, u, v;

template <typename T>
inline void read(T &f) {
    f = 0; T fu = 1; char c = getchar();
    while (c < '0' || c > '9') { if (c == '-') { fu = -1; } c = getchar(); }
    while (c >= '0' && c <= '9') { f = (f << 3) + (f << 1) + (c & 15); c = getchar(); }
    f *= fu;
}
 
template <typename T>
void print(T x) {
    if (x < 0) putchar('-'), x = -x;
    if (x < 10) putchar(x + 48);
    else print(x / 10), putchar(x % 10 + 48);
}
 
template <typename T>
void print(T x, char t) {
    print(x); putchar(t);
}

void solve() {
  read(n);
  for (int i = 1; i <= n; i++) {
    read(a[i]);
    adj[i].clear();
  }

  for (int i = 1; i < n; i++) {
    read(u); read(v);
    adj[u].push_back(v);
    adj[v].push_back(u);
  }

  const int M = std::log2(n) + 1;
  for (int i = 1; i <= n; i++) {
    for (int j = 1; j <= M; j++) {
      dp[i][j] = 0;
    }
  }
  auto dfs = [&](auto self, int u, int p) -> void {
    for (int i = 1; i <= M; i++) {
      dp[u][i] = a[u] * i;
    }

    for (auto v : adj[u]) {
      if (v == p) {
        continue;
      }
      self(self, v, u);

      for (int i = 1; i <= M; i++) {
        i64 min = inf;
        for (int j = 1; j <= M; j++) {
          if (j != i) {
            min = std::min(min, dp[v][j]);
          }
        }
        dp[u][i] += min;
      }
    }
  };
  dfs(dfs, 1, 0);

  ans = inf;
  for (int i = 1; i <= M; i++) {
    ans = std::min(ans, dp[1][i]);
  }
  print(ans, '\n');
}

int main() {
  int T;
  read(T);
  while (T--) {
    solve();
  }
  return 0;
}

Bonus

在更新 dp[U][i] 之前,可以记录 dp[V][j] 的前缀、后缀最小值,这样可以实现 O(logn) 的更新。总的时间复杂度可以优化到 O(nlogn)

点击查看代码
#include <bits/stdc++.h>

using i64 = long long;

constexpr i64 inf = 1E18;
constexpr int N = 3E5 + 5;

i64 a[N], dp[N][30], pre[30], suf[30], ans, min;
std::vector <int> adj[N];
int n, u, v;

template <typename T>
inline void read(T &f) {
    f = 0; T fu = 1; char c = getchar();
    while (c < '0' || c > '9') { if (c == '-') { fu = -1; } c = getchar(); }
    while (c >= '0' && c <= '9') { f = (f << 3) + (f << 1) + (c & 15); c = getchar(); }
    f *= fu;
}
 
template <typename T>
void print(T x) {
    if (x < 0) putchar('-'), x = -x;
    if (x < 10) putchar(x + 48);
    else print(x / 10), putchar(x % 10 + 48);
}
 
template <typename T>
void print(T x, char t) {
    print(x); putchar(t);
}

void solve() {
  read(n);
  for (int i = 1; i <= n; i++) {
    read(a[i]);
    adj[i].clear();
  }

  for (int i = 1; i < n; i++) {
    read(u); read(v);
    adj[u].push_back(v);
    adj[v].push_back(u);
  }

  const int M = std::log2(n) + 1;
  for (int i = 1; i <= n; i++) {
    for (int j = 1; j <= M; j++) {
      dp[i][j] = 0;
    }
  }
  auto dfs = [&](auto self, int u, int p) -> void {
    for (int i = 1; i <= M; i++) {
      dp[u][i] = a[u] * i;
    }

    for (auto v : adj[u]) {
      if (v == p) {
        continue;
      }
      self(self, v, u);

      pre[1] = dp[v][1]; suf[M] = dp[v][M];
      for (int i = 2; i <= M; i++) {
        pre[i] = std::min(pre[i - 1], dp[v][i]);
      }
      for (int i = M - 1; i >= 1; i--) {
        suf[i] = std::min(suf[i + 1], dp[v][i]);
      }
      for (int i = 1; i <= M; i++) {
        min = inf;
        if (i - 1 >= 1) {
          min = std::min(min, pre[i - 1]);
        }
        if (i + 1 <= M) {
          min = std::min(min, suf[i + 1]);
        }
        dp[u][i] += min;
      }
    }
  };
  dfs(dfs, 1, 0);

  ans = inf;
  for (int i = 1; i <= M; i++) {
    ans = std::min(ans, dp[1][i]);
  }
  print(ans, '\n');
}

int main() {
  int T;
  read(T);
  while (T--) {
    solve();
  }
  return 0;
}
posted @   yanhy-orz  阅读(49)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 地球OL攻略 —— 某应届生求职总结
· 周边上新:园子的第一款马克杯温暖上架
· Open-Sora 2.0 重磅开源!
· 提示词工程——AI应用必不可少的技术
· .NET周刊【3月第1期 2025-03-02】
点击右上角即可分享
微信分享提示