树分治杂谈

树分治杂谈

​ 首先在这块内容之前,先介绍一些基础的树论知识。

​ 设 T 为任意 n 个结点的有根树,其根结点为 root,集合 V,E 分别为 T 的结点集合、边集合。定义 Su 表示结点 u 子树内所有结点的集合,depu 表示结点 u 的深度 (根结点的深度为 1) 。记结点 u 的儿子集合为 chu,其父亲为 fau,记 degu=|chu|

​ 设 ancu 为结点 u 在树 T 上的祖先集合。定义 lca(u,v) 为树 T 上结点 u,v 的最近公共祖先,即 lca(u,v) 是满足 xancuancvdepx 最大的结点 x。定义 path(u,v) 为树 T 上结点 u,v 的简单路径上的点集合。记 dis(u,v)=|path(u,v)|1。特殊地,有 path(u,u)={u},dis(u,u)=0

​ 则以下结论成立:

  1. |Sroot|=n,uVdegu=n1
  2. lca(u,v)=vuSvvancu
  3. depu=vV[uSv]=|ancu|
  4. uV|Su|=uVdepu=uV|ancu|=O(n2)
  5. uV(2|Su|+x,ychu[xy]|Sx||Sy|)=n2+n
  6. dis(u,v)=depu+depv2deplca(u,v)

​ 下面我们来介绍一下树的重心。定义结点 g 为无根树 T(V,E) 的重心,其满足当 Tg 为根时表达式 maxvchg|Sv| 的值最小。由此可知任意无根树 T 的重心可能不止有一个。树的重心满足以下性质:

  1. 任意无根树 T(V,E) 的重心最多有两个 g,gV,并且存在 e(g,g)E
  2. 当无根树 T(V,E) 以重心 g 为根时,2maxvchg|Sv||V|
  3. 设无根树 T(V,E) 重心为 g,g 定义函数 F(x)=vVdis(x,v) 则有 F(g)=F(g)=minvV{F(v)}
  4. 在无根树 T(V,E) 上添加或删除一个叶结点,其重心最多只移动一条边的距离
  5. 若无根树 T(V,E) 有边权,那么重心 g,g 的位置与边权无关
  6. 设无根树 T(V,E) 以任意结点 x 为根,则重心 g 满足 2|Sg||V|depg 最大

静态链分治

​ 链分治的最基本的思想就是继承。无论是哪一种链分治,都是在儿子中选一个特殊的儿子作为"重儿子",将重儿子子树内的信息以较低的时间复杂度 O(1)O(degupolylogdegu) 直接继承到父亲上,以此减小运算次数。因此一个做法可以用链分治优化的前提是它所维护的东西有可继承性

​ 静态链分治包括"dsu on tree"和"长链剖分"两种,它们针对不同种类暴力进行不同的优化。

dsu on tree

​ 考虑到某种树上问题,直接做需要枚举每一个结点 u,然后在子树 u 内"搞事情",其中"搞事情"这一部分运算次数的数量级为 O(|Su|),于是暴力做法的总运算次数为 O(uV|Su|)=O(n2)

​ 那么这种方法就有可能可以用 dsu on tree 优化。我们的出发点是尽可能的减少运算次数,但我们又只能继承"重儿子",因此可以将"重儿子"选为运算次数最多的那个儿子。设 hu 为点 u 的"重儿子",则应该有 |Shu|=maxvchu{|Sv|},即 hu=hsonu。因此总运算次数的数量级为 O(uV(|Su||Shsonu|))=O(nlogn)

维护子树信息

​ 大部分维护子树信息的运算次数数量级都是 O(|Su|),这使得 dsu on tree 可以轻松优化这些做法。当 dsu on tree 用于统计子树信息时,还有一个结论:

结论 在 dsu on tree 的过程中,每个结点上的信息最多会被统计 logn 次。

证明 考察结点 x,由重链剖分的结论知路径 rootx 上最多有 logn 个轻儿子;而子树 u 内的所有结点被统计一次当且仅当 u 是其父亲的轻儿子,因此结点 x 被统计的次数的数量级为 O(logn)

这也说明了 dsu on tree 的总运算次数数量级为 O(nlogn)

​ 大致代码如下:

void sack(Node u, bool keep = false) {
  for (Node v: u.Childs)
    if (v != u.hson) sack(v, false);
  if (u.hson.exist)
    sack(u.hson, true);
  for (Node v: u.Childs)
    if (v != u.hson)
      for (Node x: v.SubTree)
        add_to_sack(x);
  add_to_sack(u);
  u.answer = calc();
  if (keep == false) clear();
} // sack ( dsu on tree )

例题 1 Lomsat gelral

​ 很经典的一道题。暴力做法是枚举每个结点 u,然后枚举 u 子树内的所有结点 x,将 ax 加入到一个容器里面,然后最后在容器里查询答案。注意到容器内的信息是可继承的,那么直接运用 dsu on tree 就行了,时间复杂度 O(nlogn)

参考代码

#include <bits/stdc++.h>
using namespace std;

static constexpr int Maxn = 1e5 + 5;

int n;
int64_t ans[Maxn];
int a[Maxn];
vector<int> g[Maxn];
int sz[Maxn], son[Maxn], dep[Maxn];
int dfn[Maxn], idfn[Maxn], ed[Maxn], dfn_index;
void sack_init(int u, int fa, int depth = 1) {
  sz[u] = 1, son[u] = 0, dep[u] = depth;
  idfn[dfn[u] = ++dfn_index] = u;
  for (const int &v: g[u]) if (v != fa) {
    sack_init(v, u, depth + 1), sz[u] += sz[v];
    if (son[u] == 0 || sz[v] > sz[son[u]]) son[u] = v;
  }
  ed[u] = dfn_index;
} // sack_init
int buc[Maxn], mxc, mxC;
int64_t sc;
void sack(int u, int fa, bool keep = false) {
  for (const int &v: g[u]) if (v != fa)
    if (v != son[u]) sack(v, u, false);
  if (son[u] != 0) sack(son[u], u, true);
  for (const int &v: g[u]) if (v != fa)
    if (v != son[u]) {
      for (int i = dfn[v]; i <= ed[v]; ++i) {
        ++buc[a[idfn[i]]];
        if (mxc == 0 || buc[a[idfn[i]]] > mxC)
          mxc = a[idfn[i]], sc = a[idfn[i]], mxC = buc[mxc];
        else if (buc[a[idfn[i]]] == buc[mxc])
          sc += a[idfn[i]];
      }
    }
  ++buc[a[u]];
  if (mxc == 0 || buc[a[u]] > mxC)
    mxc = a[u], sc = a[u], mxC = buc[mxc];
  else if (buc[a[u]] == buc[mxc])
    sc += a[u];
  ans[u] = sc;
  if (!keep) {
    for (int i = dfn[u]; i <= ed[u]; ++i)
      buc[a[idfn[i]]]--;
    mxc = 0, sc = 0, mxC = 0;
  }
} // sack

int main(void) {

  scanf("%d", &n);
  for (int i = 1; i <= n; ++i)
    scanf("%d", &a[i]);
  for (int i = 1; i < n; ++i) {
    int u, v; scanf("%d%d", &u, &v);
    g[u].push_back(v);
    g[v].push_back(u);
  }
  sack_init(1, 0);
  sack(1, 0);
  for (int i = 1; i <= n; ++i)
    printf("%lld%c", ans[i], " \n"[i == n]);

  exit(EXIT_SUCCESS);
} // main

例题 2 Tree Requests

​ 对每一个深度开一个桶存放当前这一深度压位后的值。如果压位后的值为 02k,则说明可以组成回文串。于是我们就可以得到一个暴力做法:离线询问,枚举每个结点 u,接着枚举 u 子树内的所有结点 x,将 ax 的贡献加入到 depx 桶内,然后处理挂在结点 u 下的询问。运用 dsu on tree 优化后,时间复杂度为 O(nlogn)

参考代码

#include <bits/stdc++.h>
using namespace std;

static constexpr int Maxn = 5e5 + 5;
static constexpr int Maxk = 26;
bool legal[1 << Maxk];

int n, m, ans[Maxn];
int par[Maxn];
vector<int> g[Maxn];
char str[Maxn];
vector<pair<int, int>> q[Maxn];
int dep[Maxn], sz[Maxn], son[Maxn];
int dfn[Maxn], idfn[Maxn], ed[Maxn], dfn_index;
void sack_init(int u, int fa, int depth = 1) {
  dep[u] = depth, sz[u] = 1, son[u] = 0;
  idfn[dfn[u] = ++dfn_index] = u;
  for (const int &v: g[u]) if (v != fa) {
    sack_init(v, u, depth + 1), sz[u] += sz[v];
    if (son[u] == 0 || sz[v] > sz[son[u]]) son[u] = v;
  }
  ed[u] = dfn_index;
} // sack_init
int cnt[Maxn];
void sack(int u, int fa, bool keep = false) {
  for (const int &v: g[u]) if (v != fa)
    if (v != son[u]) sack(v, u, false);
  if (son[u]) sack(son[u], u, true);
  for (const int &v: g[u]) if (v != fa)
    if (v != son[u]) {
      for (int i = dfn[v]; i <= ed[v]; ++i) {
        int x = idfn[i];
        cnt[dep[x]] ^= (1 << str[x] - 'a');
      }
    }
  cnt[dep[u]] ^= (1 << str[u] - 'a');
  for (const auto &[d, id]: q[u]) {
    ans[id] = legal[cnt[d]];
  }
  if (!keep) {
    for (int i = dfn[u]; i <= ed[u]; ++i)
      cnt[dep[idfn[i]]] ^= (1 << str[idfn[i]] - 'a');
  }
} // sack

int main(void) {
  memset(legal, false, sizeof(legal));
  legal[0] = true;
  for (int i = 0; i < Maxk; ++i)
    legal[(1 << i)] = true;
  scanf("%d%d", &n, &m);
  for (int i = 2; i <= n; ++i)
    scanf("%d", &par[i]), g[par[i]].push_back(i);
  scanf("%s", str + 1);
  for (int i = 1, u, d; i <= m; ++i)
    scanf("%d%d", &u, &d), q[u].push_back({d, i});
  sack_init(1, 0);
  sack(1, 0);
  for (int i = 1; i <= m; ++i)
    printf("%s\n", ans[i] ? "Yes" : "No");
  exit(EXIT_SUCCESS);
} // main

例题 3 Just Kingdom

​ 一道神仙题。对 dsu on tree 的更深层次理解有很大的帮助。

​ 不难发现,子树 u 的需求量为 su=vSuwv,而题目中所描述的发钱过程则类似于灌水的过程,其中子结点 v 对应的水管容量为 sv。现在考虑若要让子结点 v 得到 sv,结点 u 至少需要得到多少水。将结点 u 的所有子结点 ( 除了 v ) 的 s 值按从低到高排序后得到序列 t,则结点 u 至少要得到的水量为 tisvti+(1+ti>sv1)sv。于是得到了一个暴力做法:对于每一个结点往上跳父结点,直到根结点。这样做的时间复杂度是 O(i=1ndepi) 的。

​ 注意到 O(i=1ndepi)=O(i=1n|Si|)。于是想到可以使用 dsu on tree 优化。但是上述的暴力是无法直接用 dsu on tree 优化的。于是我们考虑一个更贴切的方法。

​ 在每一个结点处存储子树中所有结点目前的答案 ( 即尚未完全加工的答案 ) ,这个结点内要做的事儿就是将所有结点目前的答案加工并上传到其父亲处。从子结点转移到父结点的过程就是合并多个子树。

​ 考虑使用 dsu on tree 优化上述过程。使用平衡树来存储当前结点内的所有答案,那么轻儿子就可以暴力插入一个个值。最难的问题就是如何将重儿子的信息直接继承到父结点上,且复杂度必须为 O(polydegu)。我们设子结点的 su 序列排序后得到 tu,那么对于同一个 u,值域区间 tixti+1 内所有 x 的加工方式都是一样的,都有 ax+b 的形式。也就是说要对若干个区间 [l,r],将这些区间中的值都作用上一次函数。又观察到在操作完后序列的单调性不变,于是可以用平衡树区间加乘懒标记维护。每个结点继承重儿子的时间复杂度为 O(degulogdegu)。由于暴力插入轻儿子时单次时间复杂度为 O(logn),这种操作的次数数量级为 O(nlogn),于是总时间复杂度 O(nlog2n)

​ 一个坑点:在继承重儿子时操作顺序应该是值域从大往小进行操作。

参考代码
#include <bits/stdc++.h>
using namespace std;
static constexpr int Maxn = 3e5 + 5;
static constexpr int64_t inf = 0x3f3f3f3f3f3f3f3f;
__attribute__((__always_inline__)) inline std::uint64_t randu64(void) {
  static std::uint64_t seed = std::chrono::steady_clock::now().time_since_epoch().count();
  seed += 0xa0761d6478bd642full; __uint128_t t = (__uint128_t)(seed ^ 0xe7037ed1a0b428dbull) * seed;
  return (t >> 64) ^ t;
} // randu64
int n, par[Maxn];
int64_t w[Maxn], ans[Maxn];
vector<int> g[Maxn];
int sz[Maxn], son[Maxn];
int64_t sw[Maxn];
struct node {
  int ls, rs, size, id;
  uint64_t fix;
  int64_t val, add, mul;
  node() = default;
  node(int64_t val, int id) : ls(0), rs(0), size(1), fix(randu64()), val(val), id(id), add(0), mul(1) { }
} tr[Maxn];
int root, tot;
void initialize(void) { tr[root = tot = 0] = node(); }
int newnode(int64_t x, int id) { tr[++tot] = node(x, id); return tot; }
__attribute__((__always_inline__)) inline
void pushup(int p) { tr[p].size = tr[tr[p].ls].size + tr[tr[p].rs].size + 1; }
__attribute__((__always_inline__)) inline
void apply(int p, int64_t add, int64_t mul) {
  if (!p) return ;
  tr[p].val = tr[p].val * mul + add;
  tr[p].mul *= mul;
  tr[p].add = tr[p].add * mul + add;
} // apply
__attribute__((__always_inline__)) inline
void pushdown(int p) {
  apply(tr[p].ls, tr[p].add, tr[p].mul);
  apply(tr[p].rs, tr[p].add, tr[p].mul);
  tr[p].add = 0, tr[p].mul = 1LL;
} // pushdown
int join(int l, int r) {
  if (!l || !r) return l | r;
  pushdown(l), pushdown(r);
  if (tr[l].fix < tr[r].fix) {
    tr[l].rs = join(tr[l].rs, r);
    pushup(l); return l;
  } else {
    tr[r].ls = join(l, tr[r].ls);
    pushup(r); return r;
  }
} // join
void split(int p, int64_t key, int &l, int &r) {
  if (!p) return (l = r = 0), void(); pushdown(p);
  if (tr[p].val <= key) l = p, split(tr[p].rs, key, tr[l].rs, r), pushup(l);
  else r = p, split(tr[p].ls, key, l, tr[r].ls), pushup(r);
} // split
void insert(int64_t x, int i) {
  int A, B; split(root, x - 1, A, B);
  root = join(A, join(newnode(x, i), B));
} // insert
void modify(int64_t wl, int64_t wr, int64_t add, int64_t mul) {
  int A, B, C; split(root, wl - 1, A, B), split(B, wr, B, C);
  apply(B, add, mul); root = join(A, join(B, C));
} // modify
vector<pair<int64_t, int>> all;
void treap_dfs(int p) {
  if (!p) return; pushdown(p);
  treap_dfs(tr[p].ls);
  all.push_back({tr[p].val, tr[p].id});
  treap_dfs(tr[p].rs);
} // treap_dfs
vector<pair<int64_t, int>> traverse(int rt) {
  all.clear(); treap_dfs(rt); return all;
} // traverse
vector<pair<int64_t, int>> t[Maxn];
void sack_init(int u) {
  sz[u] = 1, sw[u] = w[u], son[u] = 0;
  for (const int &v: g[u]) {
    sack_init(v), sz[u] += sz[v], sw[u] += sw[v];
    if (son[u] == 0 || sz[v] > sz[son[u]]) son[u] = v;
  }
} // sack_init
void sack(int u) {
  for (const int &v: g[u]) if (v != son[u])
    sack(v), t[v] = traverse(root), initialize();
  vector<int64_t> ws, wall;
  for (const int &v: g[u]) wall.push_back(sw[v]);
  for (const int &v: g[u]) if (v != son[u]) ws.push_back(sw[v]);
  sort(ws.begin(), ws.end());
  sort(wall.begin(), wall.end());
  vector<int64_t> wsp = ws, wallp = wall;
  for (int i = 1; i < (int)ws.size(); ++i) wsp[i] += wsp[i - 1];
  for (int i = 1; i < (int)wall.size(); ++i) wallp[i] += wallp[i - 1];
  if (son[u] != 0) {
    sack(son[u]);
    if (!ws.empty()) {
      modify(ws.back(), inf, wsp.back(), 0 + 1);
      for (int i = (int)ws.size() - 2; i >= 0; --i)
        modify(ws[i], ws[i + 1] - 1, wsp[i], (int)ws.size() - i);
      modify(-inf, ws[0] - 1, 0LL, (int)ws.size() + 1);
    }
  }
  for (const int &v: g[u]) if (v != son[u]) {
    for (const auto &[x, i]: t[v]) {
      auto it = upper_bound(wall.begin(), wall.end(), x);
      int index = it - wall.begin() - 1;
      int64_t sum = (index >= 0 ? wallp[index] : 0), mul = (int)wall.size() - index;
      (it == wall.end() || sw[v] < *it) ? (sum -= sw[v]) : (mul--);
      insert(sum + mul * x, i);
    }
    t[v].clear();
  }
  insert(sw[u], u);
} // sack
int main(void) {
  scanf("%d", &n);
  for (int i = 1; i <= n; ++i) scanf("%d%lld", &par[i], &w[i]);
  for (int i = 1; i <= n; ++i) g[par[i]].push_back(i);
  sack_init(0); initialize(); sack(0);
  auto res = traverse(root);
  for (const auto &[x, i]: res) ans[i] = x;
  for (int i = 1; i <= n; ++i) printf("%lld\n", ans[i]);
  exit(EXIT_SUCCESS);
} // main

例题 4 [NOIP2016 提高组] 天天爱跑步

​ 设有一条路径 (u,v),令 z=lca(u,v)。设 x 为路径 (u,v) 上的一个结点,那么 (u,v)x 有贡献当且仅当 dis(u,x)=tx。也就是

{depu=tx+depx,(xpath(u,z))depu2depz=txdepx.(xpath(v,z))

注意上式的左边都是关于 x 的常量,而右边都是关于 x 的变量。

​ 于是我们现在需要做的是:对于结点 x 求出有多少条在集合 U 中的路径 (u,v) 经过点 x 且满足上述关系式。树上差分之后开两个桶分别维护 depudepu2depz 就变成了子树统计问题,使用 dsu on tree 即可。时间复杂度为 O(nlogn)

参考代码

#include <bits/stdc++.h>
using namespace std;
static constexpr int Maxn = 3e5 + 5;
int n, m, w[Maxn], ans[Maxn];
vector<int> g[Maxn];
int par[Maxn], dep[Maxn], sz[Maxn], hson[Maxn];
int top[Maxn], dfn[Maxn], idfn[Maxn], dn;
void dfs1(int u, int fa, int depth) {
  par[u] = fa, dep[u] = depth; sz[u] = 1, hson[u] = 0;
  for (const int &v: g[u]) if (v != par[u]) {
    dfs1(v, u, depth + 1), sz[u] += sz[v];
    if (hson[u] == 0 || sz[v] > sz[hson[u]]) hson[u] = v;
  }
} // dfs1
void dfs2(int u, int topv) {
  top[u] = topv, idfn[dfn[u] = ++dn] = u;
  if (hson[u] != 0) dfs2(hson[u], topv);
  for (const int &v: g[u]) if (v != par[u])
    if (v != hson[u]) dfs2(v, v);
} // dfs2
inline int get_lca(int u, int v) {
  for (; top[u] != top[v]; v = par[top[v]])
    if (dep[top[u]] > dep[top[v]]) swap(u, v);
  return dep[u] < dep[v] ? u : v;
} // get_lca
vector<int> au[Maxn], ad[Maxn], du[Maxn], dd[Maxn];
static constexpr int BASE = 3e5;
int buc1[Maxn * 3], buc2[Maxn * 3];
inline void Add(int x, int W) {
  for (int v: au[x]) buc1[v] += W;
  for (int v: du[x]) buc1[v] -= W;
  for (int v: ad[x]) buc2[v] += W;
  for (int v: dd[x]) buc2[v] -= W;
} // Add
inline int Ask(int u) {
  return buc1[w[u] + dep[u] + BASE] + buc2[w[u] - dep[u] + BASE];
} // Ask
void sack(int u, bool keep) {
  for (const int &v: g[u]) if (v != par[u] && v != hson[u])
    sack(v, false);
  if (hson[u] != 0) sack(hson[u], true);
  Add(u, 1);
  for (const int &v: g[u]) if (v != par[u] && v != hson[u])
    for (int i = dfn[v]; i < dfn[v] + sz[v]; ++i)
      Add(idfn[i], 1);
  ans[u] = Ask(u);
  if (!keep)
    for (int i = dfn[u]; i < dfn[u] + sz[u]; ++i)
      Add(idfn[i], -1);
} // sack
int main(void) {
  scanf("%d%d", &n, &m);
  for (int i = 2; i <= n; ++i) {
    int u, v;
    scanf("%d%d", &u, &v);
    g[u].push_back(v);
    g[v].push_back(u);
  }
  for (int i = 1; i <= n; ++i)
    scanf("%d", &w[i]);
  dfs1(1, 0, 1), dn = 0, dfs2(1, 1);
  for (int i = 1; i <= m; ++i) {
    int u, v;
    scanf("%d%d", &u, &v);
    int z = get_lca(u, v);
    au[u].push_back(dep[u] + BASE);
    ad[v].push_back(dep[u] - 2 * dep[z] + BASE);
    du[z].push_back(dep[u] + BASE);
    if (par[z] != 0) dd[par[z]].push_back(dep[u] - 2 * dep[z] + BASE);
  }
  sack(1, false);
  for (int i = 1; i <= n; ++i)
    printf("%d%c", ans[i], " \n"[i == n]);
  exit(EXIT_SUCCESS);
} // main

习题 1 Tree and Queries

习题 2 Blood Cousins

习题 3 Blood Cousins Return

习题 4 [Vani有约会]雨天的尾巴

习题 5 NFLSOJ 12521. 山花

题目描述

4.2 Description

​ 今天又是去采花的好日子啊~~

Quelle 站在山顶,发现今日的花树们,在不甚平坦的山上,焕发出了别样的光彩。

​ 简单来说,它们组成了一棵以 1 为根的树(QuQ……

​ 每棵花树上有若干朵花,具体的,在编号为 i 的花树上有 ai 朵花。

​ 山花自然是越多越好,但 Quelle 却做不到雨露均沾……

​ 于是 Quelle 决定从某一个节点 u 开始对其子树中与 u 距离小于 K 的节点代表的花树进行采 摘。特别的,节点 u 代表的花树也会被采摘。

​ 依旧受限于精力,Quelle 并不会亲自去采摘而是使用 Extremely Strong 的工具进行采摘。

​ 我们定义一个工具的能力为 cQuelle 会采摘的山树集合为 T

​ 那么 Quelle 能采摘到的山花数量:

fT=iTgcd(ai,c)

​ 现在对于给定的树和阀值 KQuelle 想要知道每一组询问的 fT

4.3 Input Format

​ 第一行,三个正整数 n,Q,K,代表花树的棵数,询问次数和阀值。

​ 接下来一行 n 个正整数,其中第 i 个数代表编号为 i 的花树的花的个数 ai

​ 接下来 n1 行,描述了花树们所形成的那棵树,每行两个正整数 u,v,代表编号为 uv 的花树直接相连。

​ 接下来 Q 行,每行描述了一次询问,包含两个正整数 x,c 代表这次 Quelle 决定从编号为 x 的花树开始采摘,这次工具的能力为 c

4.4 Output Format

​ 共 Q 行,每行一个整数 ans,满足 ansfi(mod998244353)

​ 其中 fi 为第 i 次询问的答案,即能采摘到的山花数量。

4.8 Limits

​ 对于 100% 的数据,n,Q,K105,ai,ci107

优化枚举树上路径

​ 形式化的就是说,给你一个有根树,让你对于每个结点 u 求所有路径 xy 的一个特定的值,其中路径 xy 满足:

  1. 路径 xy 经过点 u,且 x,ySu。即有 lca(x,y)=u
  2. 路径 xy 满足一个题目给定的条件 (或无条件)
  3. 路径 xy 的条件及贡献可以转化成 单元的。所谓的单元的函数,就是只有一个变量的函数。

普通的做法就是枚举路径 xy,然后更新其 LCA 的答案。枚举路径的数量级为 O(n2)

​ 接下来考虑用 dsu on tree 优化,但由于刚刚的做法并不涉及到子树有关的信息,因此不太好直接使用 dsu on tree。考虑换一种枚举顺序,先枚举路径 xy 的 LCA 点 z,然后用类似树形背包的转移顺序枚举 x,y ( 注意确保 lca(x,y)=z )。由于枚举的都是所有路径,这样暴力的数量级也是 O(n2) 。但这个暴力也不好用 dsu on tree 优化,因为在 z 子树枚举 x,y 的数量级最坏情况是 O(|Sz|2)

​ 由于条件及贡献都是单元的,可以将其拆分为 f(x)g(y)h(z),其中 h(z)z 确定的时候是常量,运算 , 是某种二元运算。于是我们就可以在枚举 y 的同时设置一个容器存放之前所有的 f(x) 满足 lca(x,y)=z。于是,这样枚举 x,y 的数量级就是严格 O(|Sz|) 的。由于容器里存放的 f(x) 是单元的,满足继承条件,因此这种枚举方法就可以用 dsu on tree 优化了。

​ 大致代码如下:

void sack(Node u, bool keep = false) {
  for (Node v: u.Childs)
    if (v != u.hson) sack(v, false);
  if (u.hson.exist)
    sack(u.hson, true);
  for (Node v: u.Childs)
    if (v != u.hson) {
      for (Node x: v.SubTree)
        updateAnswer(u.answer, ask(x));
      for (Node x: v.SubTree)
        add_to_sack(x);
    }
  updateAnswer(u.answer, ask(u));
  add_to_sack(u);
  if (keep == false) clear();
} // sack ( dsu on tree )

例题 1 Arpa’s letter-marked tree and Mehrdad’s Dokhtar-kosh paths

​ 设 mu 为结点 u 到根的路径上字符压位后的值,则树上的路径 uv 满足要求当且仅当 mumv=02k。直接 dsu on tree 优化即可。时间复杂度为 O(||nlogn)

参考代码

#include <bits/stdc++.h>
using namespace std;
template<typename _Tp> _Tp &min_eq(_Tp &x, const _Tp &y) { return x = min(x, y); }
template<typename _Tp> _Tp &max_eq(_Tp &x, const _Tp &y) { return x = max(x, y); }
static constexpr int Maxn = 5e5 + 5, Maxk = 22;
int n, par[Maxn], ans[Maxn];
vector<pair<int, int>> g[Maxn];
int sz[Maxn], son[Maxn], dep[Maxn];
int mask[Maxn];
void sack_init(int u, int fa, int depth = 1) {
  sz[u] = 1, son[u] = 0, dep[u] = depth;
  for (const auto &[v, w]: g[u]) if (v != fa) {
    mask[v] = (mask[u] ^ (1 << w));
    sack_init(v, u, depth + 1), sz[u] += sz[v];
    if (son[u] == 0 || sz[v] > sz[son[u]]) son[u] = v;
  }
} // sack_init
int buc[1 << Maxk];
template<typename F>
void dfs(int u, int fa, F f) {
  f(u);
  for (const auto &[v, w]: g[u]) if (v != fa)
    dfs(v, u, f);
} // dfs
void sack(int u, int fa, bool keep = true) {
  for (const auto &[v, w]: g[u]) if (v != fa)
    if (v != son[u]) sack(v, u, false);
  if (son[u]) sack(son[u], u, true);
  auto add = [&](int x) { max_eq(buc[mask[x]], dep[x]); };
  auto del = [&](int x) { buc[mask[x]] = 0; };
  auto qry = [&](int x) {
    if (buc[mask[x]])
      max_eq(ans[u], dep[x] + buc[mask[x]] - 2 * dep[u]);
    for (int i = 0; i < Maxk; ++i)
      if (buc[mask[x] ^ (1 << i)] != 0)
        max_eq(ans[u], dep[x] + buc[mask[x] ^ (1 << i)] - 2 * dep[u]);
  };
  for (const auto &[v, w]: g[u]) if (v != fa)
    if (v != son[u]) dfs(v, u, qry), dfs(v, u, add);
  qry(u), add(u);
  if (keep == false) dfs(u, fa, del);
} // sack
void dfs_ans(int u, int fa) {
  for (const auto &[v, w]: g[u]) if (v != fa)
    dfs_ans(v, u), max_eq(ans[u], ans[v]);
} // dfs_ans
int main(void) {
  scanf("%d", &n);
  for (int i = 2; i <= n; ++i) {
    char ch;
    scanf("\n%d %c", &par[i], &ch);
    g[par[i]].push_back({i, ch - 'a'});
    g[i].push_back({par[i], ch - 'a'});
  }
  sack_init(1, 0);
  sack(1, 0, true);
  dfs_ans(1, 0);
  for (int i = 1; i <= n; ++i)
    printf("%d%c", ans[i], " \n"[i == n]);
  exit(EXIT_SUCCESS);
} // main

例题 2 Sum of Prefix Sums

​ 定义 g(x,y) 为路径 xy 的"sum",且 f(x,y) 为路径 xy 的"sum of prefix sum"。下文中为了方便,定义 su=f(u,root),tu=f(root,u),wu=g(root,u)

​ 设有路径 uv,点 z=lca(u,v)。设 z1z 的父亲,z2z 的儿子且满足 z2 在路径 uz 上。则有

f(u,z2)=susz(wuwz)depzf(z,v)=tvtz1wz1(depvdepz1)g(u,z2)=wuwzf(u,v)=f(u,z2)+g(u,z2)(depvdepz1)+f(z,v)=[susz(wuwz)]+(wuwz)(depvdepz1) +tvtz1wz1(depvdepz1)=suszwudepz+wzdepz +wudepvwzdepvwudepz1+wzdepz1 +tvtz1wz1depv+wz1depz1

​ 注意到上式中各项都是单元的,因此可以考虑用 dsu on tree 优化。由于路径是有向的,需要做两次 dsu on tree:第一次枚举 u,第二次枚举 v。可以发现,无论是枚举 u 还是枚举 v,贡献式都可以表示为 kx+b+a 的形式,其中 x,a 是枚举的值,k,b 是扔到容器里的值。于是可以用李超树或动态凸包维护。总时间复杂度为 O(nlog2n)

参考代码

#include <bits/stdc++.h>
using namespace std;
template<typename _Tp> _Tp &min_eq(_Tp &x, const _Tp &y) { return x = min(x, y); }
template<typename _Tp> _Tp &max_eq(_Tp &x, const _Tp &y) { return x = max(x, y); }

static constexpr int Maxn = 1.5e5 + 5;

int n;
int64_t ans, a[Maxn];
vector<int> g[Maxn];
int dfn[Maxn], idfn[Maxn], ed[Maxn], dfn_index;
int sz[Maxn], son[Maxn], dep[Maxn];
int64_t s[Maxn], t[Maxn], X[Maxn];
void sack_init(int u, int fa, int depth = 1) {
  sz[u] = 1, son[u] = 0; dep[u] = depth;
  X[u] = X[fa] + a[u];
  s[u] = s[fa] + dep[u] * a[u];
  t[u] = t[fa] + X[u];
  idfn[dfn[u] = ++dfn_index] = u;
  for (const int &v: g[u]) if (v != fa) {
    sack_init(v, u, depth + 1), sz[u] += sz[v];
    if (son[u] == 0 || sz[v] > sz[son[u]]) son[u] = v;
  }
  ed[u] = dfn_index;
} // sack_init
// add a line f(x), query max{f(a)}
namespace HULL {
  using value_t = int64_t;
  static constexpr value_t null_value = std::numeric_limits<value_t>::min();
  bool flag;
  struct line {
    value_t k, b;
    mutable function<const line* ()> nxt;
    friend bool operator < (const line &a, const line &b) {
      if (!flag) return a.k < b.k;
      const line *s = a.nxt();
      if (!s) return false;
      return a.b - s->b < b.b * (s->k - a.k);
    }
  }; // line
  struct dynamic_hull : public multiset<line> {
    bool bad(iterator it) {
      if (it == this->end()) return false;
      auto nxt = next(it);
      if (it == this->begin()) {
        if (nxt == this->end()) return false;
        return it->k == nxt->k && it->b <= nxt->b;
      }
      auto prv = prev(it);
      if (nxt == this->end())
        return it->k == prv->k && it->b <= prv->b;
      return (prv->b - it->b) * (nxt->k - it->k) >= (it->b - nxt->b) * (it->k - prv->k);
    } // dynamic_hull::bad
    void add(value_t k, value_t b) {
      auto it = this->insert((line){k, b});
      it->nxt = [=]() { return next(it) == this->end() ? nullptr : &*next(it); };
      if (bad(it)) return this->erase(it), void();
      while (next(it) != this->end() && bad(next(it))) this->erase(next(it));
      while (it != this->begin() && bad(prev(it))) this->erase(prev(it));
    } // dynamic_hull::add
    value_t query(value_t x) {
      if (this->empty()) return null_value;
      flag = true;
      line l = *lower_bound((line){0, x});
      flag = false;
      return l.k * x + l.b;
    } // dynamic_hull::query
  }; // dynamic_hull
} // namespace HULL
using HULL::dynamic_hull;
dynamic_hull h;
void sack1(int u, int fa, bool keep = true) {
  for (const int &v: g[u]) if (v != fa)
    if (v != son[u]) sack1(v, u, false);
  if (son[u]) sack1(son[u], u, true);
  auto ask = [&](int x) {
    int64_t res = h.query(X[x] - X[u] - X[fa]);
    res += s[x] - s[u] - X[x] * dep[u] + X[u] * dep[u] - X[x] * dep[fa] + X[u] * dep[fa] - t[fa] + X[fa] * dep[fa];
    return res;
  };
  for (const int &v: g[u]) if (v != fa)
    if (v != son[u]) {
      for (int i = dfn[v]; i <= ed[v]; ++i) {
        int x = idfn[i];
        max_eq(ans, ask(x));
      }
      for (int i = dfn[v]; i <= ed[v]; ++i) {
        int x = idfn[i];
        h.add(dep[x], t[x]);
      }
    }
  h.add(dep[u], t[u]);
  max_eq(ans, ask(u));
  if (!keep) h.clear();
} // sack1
void sack2(int u, int fa, bool keep = true) {
  for (const int &v: g[u]) if (v != fa)
    if (v != son[u]) sack2(v, u, false);
  if (son[u]) sack2(son[u], u, true);
  auto ask = [&](int x) {
    int64_t res = h.query(dep[x] - dep[u] - dep[fa]);
    res += t[x] - t[fa] - X[fa] * dep[x] + X[fa] * dep[fa] - X[u] * dep[x] + X[u] * dep[fa] - s[u] + X[u] * dep[u];
    return res;
  };
  for (const int &v: g[u]) if (v != fa)
    if (v != son[u]) {
      for (int i = dfn[v]; i <= ed[v]; ++i) {
        int x = idfn[i];
        max_eq(ans, ask(x));
      }
      for (int i = dfn[v]; i <= ed[v]; ++i) {
        int x = idfn[i];
        h.add(X[x], s[x]);
      }
    }
  h.add(X[u], s[u]);
  max_eq(ans, ask(u));
  if (!keep) h.clear();
} // sack2
int main(void) {
  scanf("%d", &n);
  for (int i = 1; i <= n - 1; ++i) {
    int u, v;
    scanf("%d%d", &u, &v);
    g[u].push_back(v);
    g[v].push_back(u);
  }
  for (int i = 1; i <= n; ++i)
    scanf("%lld", &a[i]);
  sack_init(1, 0);
  sack1(1, 0, false);
  sack2(1, 0, true);
  printf("%lld\n", ans);
  exit(EXIT_SUCCESS);
} // main

习题 1 Digit Tree

习题 2 [IOI2011]Race

其它奇奇怪怪的杂题

例题 1

题目描述

题目背景

本题所用主要算法为 dsu on tree ,但没有那么裸。

题目描述

给定一棵树。

[L,R] 表示树上序号在 [L,R] 内的点的集合。

同时,令函数 F({S}) 表示令集合 S 内的点联通的需要的最小边数。

问题则是求:

i=1nj=inF([i,j])

输入格式

1 行,一个整数 n 表示点数。

2n 行,每行两个整数 u,v 表示结点 u 到结点 v 有一条边。

输出格式

一行一个整数,表示答案,意义见题目描述。

输入输出样例

输入 #1

4
1 4
1 3
2 4

输出 #1

16

说明/提示

对于 100% 的数据,满足 1n100,000

​ 直接优化枚举区间的过程比较难,考虑对每条边分别考虑贡献。

​ 于是我们考虑暴力。枚举点 z 时计算 z 父边的贡献,此时可以将所有的点分为 在/不在 z 子树内这两大类。设 bu=[uSz],那么 z 父边的贡献为 W=1lrn[0<lirbirl],即区间中包含 01 的区间个数。考虑反面计算,算出区间中只包含 01 的区间个数,然后用区间总数减一下。这显然可以直接 O(n) 计算。于是,我们就得到了一个 O(n2) 做法。

​ 注意到在 z 子树内,满足 bi=1i 的个数为 |Sz|。那么上面的暴力计算就是 O(|Sz|)。于是可以考虑用 dsu on tree 优化。我们现在要做的就是维护一个数据结构,支持:

  1. 将位置 i 处赋值为 1
  2. 计算有多少个区间内只有 01

显然我们只要维护 01 的极长连续段即可。对于 1 可以用并查集,对于 0 可以用 set 维护。显然这个容器满足可继承性,清空时只需将 set 初始化成一个大区间。

​ 总时间复杂度:O(nlog2n)

参考代码

#include <bits/stdc++.h>
using namespace std;
static constexpr int Maxn = 1e5 + 5;
int n;
int64_t ans;
vector<int> g[Maxn];
int sz[Maxn], son[Maxn], dep[Maxn];
int dfn[Maxn], ed[Maxn], idfn[Maxn], dfn_index;
void sack_init(int u, int fa, int depth = 1) {
  sz[u] = 1, son[u] = 0, dep[u] = depth;
  idfn[dfn[u] = ++dfn_index] = u;
  for (const int &v: g[u]) if (v != fa) {
    sack_init(v, u, depth), sz[u] += sz[v];
    if (son[u] == 0 || sz[v] > sz[son[u]]) son[u] = v;
  }
  ed[u] = dfn_index;
} // sack_init
inline int64_t calc(int x) { return (int64_t)x * (x + 1) / 2; }
int64_t curAns;
set<pair<int, int>> s;
int m, b[Maxn], v[Maxn];
int fa[Maxn], siz[Maxn];
int fnd(int x) { return fa[x] == x ? x : fa[x] = fnd(fa[x]); }
void unite(int x, int y) {
  x = fnd(x), y = fnd(y);
  if (x == y) return ;
  curAns -= calc(siz[x]);
  curAns -= calc(siz[y]);
  if (siz[x] < siz[y]) swap(x, y);
  fa[y] = x, siz[x] += siz[y];
  curAns += calc(siz[x]);
} // unite
void clear(void) {
  curAns = (int64_t)n * (n + 1) / 2;
  s.clear(), s.insert({1, n});
  for (; m; --m) v[b[m]] = 0;
} // clear
void insert(int x) {
  v[b[++m] = x] = 1;
  fa[x] = x, siz[x] = 1;
  curAns += calc(siz[fnd(x)]);
  if (v[x - 1]) unite(x - 1, x);
  if (v[x + 1]) unite(x + 1, x);
  auto it = prev(s.upper_bound({x, n + 1}));
  int l = it->first, r = it->second; s.erase(it);
  curAns -= calc(r - l + 1);
  if (l <= x - 1) curAns += calc(x - l), s.insert({l, x - 1});
  if (x + 1 <= r) curAns += calc(r - x), s.insert({x + 1, r});
} // insert
void sack(int u, int fa, bool keep = false) {
  for (const int &v: g[u]) if (v != fa)
    if (v != son[u]) sack(v, u, false);
  if (son[u]) sack(son[u], u, true);
  insert(u);
  for (const int &v: g[u]) if (v != fa)
    if (v != son[u]) {
      for (int i = dfn[v]; i <= ed[v]; ++i)
        insert(idfn[i]);
    }
  int64_t res = (int64_t)n * (n + 1) / 2 - curAns;
  ans += res;
  if (!keep) clear();
} // sack
int main(void) {
  scanf("%d", &n);
  for (int i = 2; i <= n; ++i) {
    int u, v;
    scanf("%d%d", &u, &v);
    g[u].push_back(v);
    g[v].push_back(u);
  }
  dfn_index = 0; sack_init(1, 0);
  clear(); sack(1, 0, true);
  printf("%lld\n", ans);
  exit(EXIT_SUCCESS);
} // main

习题 1 射手座之日

题目描述

题目背景

为了报春日抢电脑的一箭之仇,电脑社团的同学向SOS团发起了挑战!他们声称:如果SOS团可以在新的电脑游戏——《Sagittarius》中取胜,就送给SOS团一人一台新式电脑;反之SOS团要归还抢走的那台电脑。为了保护电脑中学姐换衣服的照片,你一定要获胜!

题目描述

游戏《Sagittarius》是一个多人闯关游戏,游戏有两个组成部分——地图 和关卡表。地图是一个n个结点的有根树,每个结点i都有权值xi,1号结点 为根;关卡表是一个{1,2,3,...,n}的排列ai。每一个回合中,系统会选取关卡表中一段连续的区间{al,al+1,,ar},设这些结点的lcak,则这次游戏胜利的话玩家将会得到xk的收益。如果进行了多次游戏,同一个结点的收益可以被重复获得。春日和你的游戏技术都非常弱,但是SOS团中有一个人形自走挂:长门有希。因此你们每次都可以获胜。现在春日想要知道,对于所有不同的回合,你们可以获得的收益之和是多少?

输入格式

1行一个数n,表示结点的个数。

第二行n1个数,第i个数是pi+1pi表示结点i的父亲是pi。数据保证pi<i

第三行n个数,a1,a2,,an,表示关卡表。数据保证这是一个排列。

第四行n个数,x1,x2,,xn,表示结点的权值。

输出格式

输出一个数表示答案。即对于所有可能的回合,你们能获得的总收益是多少。

输入输出样例

输入 #1

5
1 1 1 1 
5 2 3 1 4 
31244 44588 57025 99626 20260 

输出 #1

565183

说明/提示

对于20%的数据,n100

对于40%的数据,n2000

对于60%的数据,n50000

对于另外20%的数据,排列ai是用如下的算法生成的:从一号点始对树 做dfs,到达一个结点的时候输出这个结点。

对于全部数据,n200000, 0xi100000, pi<i, ai是一个排列。

来自毒瘤出题人ljt12138。

长链剖分

​ 考虑到某种树上问题,直接做需要枚举每一个结点 u,然后在子树 u 内"搞事情",其中"搞事情"这一部分运算次数的数量级为 O(mxdu),于是暴力做法的总运算次数为 O(uVmxdu)=O(n2)

​ 那么这种方法就有可能可以用长链剖分优化。和 dsu on tree 类似,我们的出发点同样是尽可能的减少运算次数,但我们又只能继承"重儿子",因此同样可以将"重儿子"选为运算次数最多的那个儿子。设 hu 为点 u 的"重儿子",则应该有 mxdhu=maxvchu{mxdv},即 hu=dsonu。因此总运算次数的数量级为 O(uV(vchumxdvmxddsonu))=O(n)

优化子树DP转移

​ 考虑 fu,k 表示子树 u 内到 u 距离为 k 的某个值,转移类似于树形背包,且 fu,k1fv,k2 合并时的复杂度为 O(min{k1,k2})。那么这个树形DP就极有可能可以用长链剖分优化。大致的思想就是在DP转移时先计算 fdsonu,将其直接继承到 fu 上,然后再枚举 u 其它的儿子 v 转移。

​ 大致代码如下:

void longsack(Node u, Node fa) {
  if (u.dson.exist)
    longsack(u.dson, u), f[u][1 .. $] = f[u.dson];
  for (Node v: u.Childs)
    if (v != u.dson)
      longsack(v, u), merge(f[u][1 .. $], f[v], v.maxDep);
  f[u][0] = calc_f(u, 0);
} // longsack

在具体实现时我们可以使用指针来模拟DP数组的偏移。

例题 1 Dominant Indices

​ 设 fu,k 为子树 u 内到 u 距离为 k 的结点数,则有 fu,k=vchufv,k1,边界为 fu,0=1

​ 考虑用长链剖分优化。在DP数组合并时顺便记录答案即可。时间复杂度为 O(n)

参考代码
#include <bits/stdc++.h>
using namespace std;
template<typename _Tp> _Tp &min_eq(_Tp &x, const _Tp &y) { return x = min(x, y); }
template<typename _Tp> _Tp &max_eq(_Tp &x, const _Tp &y) { return x = max(x, y); }
static constexpr int Maxn = 1e6 + 5;
int n, ans[Maxn];
vector<int> g[Maxn];
int len[Maxn], son[Maxn];
void longsack_init(int u, int fa, int depth = 1) {
  len[u] = 0, son[u] = 0;
  for (const int &v: g[u]) if (v != fa) {
    longsack_init(v, u, depth + 1);
    max_eq(len[u], len[v]);
    if (son[u] == 0 || len[v] > len[son[u]])
      son[u] = v;
  }
  ++len[u];
} // longsack_init
int tmp_dp[Maxn * 2], *tmp_dp_ptr = tmp_dp;
int *dp[Maxn];
void longsack(int u, int fa) {
  if (son[u]) {
    dp[son[u]] = dp[u] + 1;
    longsack(son[u], u);
    ans[u] = ans[son[u]] + 1;
  }
  dp[u][0]++;
  for (const int &v: g[u]) if (v != fa)
    if (v != son[u]) {
      dp[v] = tmp_dp_ptr; tmp_dp_ptr += len[v] + 1;
      longsack(v, u);
      for (int i = 0; i < len[v]; ++i) {
        dp[u][i + 1] += dp[v][i];
        if (dp[u][i + 1] > dp[u][ans[u]])
          ans[u] = i + 1;
        else if (dp[u][i + 1] == dp[u][ans[u]] && ans[u] > i + 1)
          ans[u] = i + 1;
      }
    }
  if (dp[u][ans[u]] <= 1) ans[u] = 0;
} // longsack
int main(void) {
  scanf("%d", &n);
  for (int i = 1; i <= n - 1; ++i) {
    int u, v;
    scanf("%d%d", &u, &v);
    g[u].push_back(v);
    g[v].push_back(u);
  }
  longsack_init(1, 0);
  dp[1] = tmp_dp_ptr; tmp_dp_ptr += len[1] + 1;
  longsack(1, 0);
  for (int i = 1; i <= n; ++i)
    printf("%d\n", ans[i]);
  exit(EXIT_SUCCESS);
} // main

例题 2 Maximum Weight Subset

​ 设 fu,i 表示子树 u 内所选点到 u 的最近距离为 i 时的答案。则有背包转移

fu,i=max{fv,i1+maxj>ifu,j,fu,i+maxjifv,j}

显然这个树形背包的时间复杂度是 O(n3),没有达到“要求”。

​ 考虑用后缀和优化。设 fu,i 表示子树 u 内所选点到 u 的最近距离 i 时的答案,那么DP转移就变成了:

fu,i=max{fv,i1+fu,max{i,ki},fu,i+fv,max{i,ki}1}

于是转移的时间复杂度就变成了 O(n2)

​ 考虑使用长链剖分优化。时间复杂度为 O(n)。需要注意一下的是边界问题。

参考代码

#include <bits/stdc++.h>
using namespace std;
template<typename _Tp> _Tp &min_eq(_Tp &x, const _Tp &y) { return x = min(x, y); }
template<typename _Tp> _Tp &max_eq(_Tp &x, const _Tp &y) { return x = max(x, y); }
static constexpr int Maxn = 205;
int n, k, a[Maxn];
vector<int> g[Maxn];
int sz[Maxn], dep[Maxn], mxd[Maxn], dson[Maxn];
void sack_init(int u, int fa, int depth) {
  dep[u] = depth, sz[u] = 1, dson[u] = 0;
  for (const int &v: g[u]) if (v != fa) {
    sack_init(v, u, depth + 1), sz[u] += sz[v];
    if (dson[u] == 0 || mxd[v] > mxd[dson[u]]) dson[u] = v;
  } mxd[u] = (dson[u] == 0 ? 0 : mxd[dson[u]]) + 1;
} // sack_init
int tf[Maxn * 2], *tfp = tf, *f[Maxn];
void longsack(int u, int fa) {
  if (dson[u] != 0) {
    f[dson[u]] = f[u] + 1;
    longsack(dson[u], u);
  }
  f[u][0] = max(a[u] + (k + 1 <= mxd[u] ? f[u][k + 1] : 0), f[u][1]);
  for (const int &v: g[u]) if (v != fa && v != dson[u]) {
    f[v] = tfp, tfp += mxd[v] + 1;
    longsack(v, u);
    static int h[Maxn];
    fill(h, h + mxd[v] + 1, 0);
    for (int i = 0; i <= mxd[v]; ++i) {
      max_eq(h[i], f[u][i] + (k - i > mxd[v] ? 0 : f[v][max(i, k + 1 - i) - 1]));
      if (i != 0) max_eq(h[i], (k + 1 - i > mxd[u] ? 0 : f[u][max(i, k + 1 - i)]) + f[v][i - 1]);
    }
    for (int i = mxd[v], mx = 0; i >= 0; --i)
      max_eq(f[u][i], max_eq(mx, h[i]));
  }
} // longsack
int main(void) {
  scanf("%d%d", &n, &k);
  for (int i = 1; i <= n; ++i) scanf("%d", &a[i]);
  for (int i = 2, u, v; i <= n; ++i)
    scanf("%d%d", &u, &v), g[u].push_back(v), g[v].push_back(u);
  sack_init(1, 0, 1);
  f[1] = tfp, tfp += mxd[1] + 1;
  longsack(1, 0);
  printf("%d\n", *max_element(f[1] + 0, f[1] + n + 1));
  exit(EXIT_SUCCESS);
} // main

例题 3 [POI2014]HOT-Hotels 加强版

​ 首先发现符合条件的 (x,y,z) 只有可能是以下这种模型:

					    b
					   / \
					  a   z
					 / \
					x   y

其中 a=lca(x,y),b=lca(x,y,z) 满足条件 dis(x,a)=dis(y,a)=dis(a,b)+dis(z,b)

​ 我们在结点 b 中计算这样的贡献。设 fu,k 表示子树 u 内到 u 距离为 k 的结点的个数,hu,k 表示子树 u 内满足 dis(x,a)=dis(y,a)=dis(b,a)+k 的无序结点对 (x,y) 的数量。则有背包转移:

hu,khu,k+hv,k+1+fu,kfv,k1fu,kfu,k+fv,k1

可以在背包转移时计算答案。于是我们得到了一个时间复杂度为 O(n2) 的做法。

​ 考虑用长链剖分优化。时间复杂度为 O(n)。需要注意的一点是此题长链剖分的空间大小。由于 h 是反向继承,因此 h 的基数组的空间要开 4 倍。

参考代码

#include <bits/stdc++.h>
using namespace std;
static constexpr int Maxn = 1e5 + 5;
int n;
int64_t ans;
vector<int> G[Maxn];
int sz[Maxn], dep[Maxn], mxd[Maxn], dson[Maxn];
void sack_init(int u, int fa, int depth) {
  dep[u] = depth, sz[u] = 1, dson[u] = 0;
  for (const int &v: G[u]) if (v != fa) {
    sack_init(v, u, depth + 1), sz[u] += sz[v];
    if (dson[u] == 0 || mxd[v] > mxd[dson[u]]) dson[u] = v;
  } mxd[u] = (dson[u] == 0 ? 0 : mxd[dson[u]]) + 1;
} // sack_init
int64_t tp[Maxn * 4], *P = tp;
int64_t *f[Maxn], *g[Maxn];
void longsack(int u, int fa) {
  if (dson[u] != 0) {
    f[dson[u]] = f[u] + 1;
    g[dson[u]] = g[u] - 1;
    longsack(dson[u], u);
  }
  f[u][0] = 1; ans += g[u][0];
  for (const int &v: G[u]) if (v != fa && v != dson[u]) {
    f[v] = P, P += mxd[v] * 2, g[v] = P, P += mxd[v] * 2;
    longsack(v, u);
    for (int j = 0; j < mxd[v]; ++j) {
      ans += f[v][j] * g[u][j + 1];
      if (j > 0) ans += f[u][j - 1] * g[v][j];
    }
    for (int j = 0; j < mxd[v]; ++j) {
      g[u][j + 1] += f[v][j] * f[u][j + 1];
      if (j > 0) g[u][j - 1] += g[v][j];
      f[u][j + 1] += f[v][j];
    }
  }
} // longsack
int main(void) {
  scanf("%d", &n);
  for (int i = 2, u, v; i <= n; ++i)
    scanf("%d%d", &u, &v), G[u].push_back(v), G[v].push_back(u);
  sack_init(1, 0, 1);
  f[1] = P, P += mxd[1] * 2, g[1] = P, P += mxd[1] * 2;
  longsack(1, 0);
  printf("%lld\n", ans);
  exit(EXIT_SUCCESS);
} // main

例题 4 [湖南集训]谈笑风生

​ 显然可以将贡献拆成两种情况计算:

  1. ab 不知道高到哪里去;
  2. ba 不知道高到哪里去。

第二种情况是很简单的,考虑如何计算第一种的贡献。

​ 设 fu,k=vSuvu[dis(u,v)k](|Sv|1) 表示结点 uk 时的答案。则有DP转移 fu,k=vchu(fv,k1+|Sv|1)。该做法时间复杂度为 O(n2)

​ 考虑用长链剖分优化。注意在继承 dson 时有全局加的过程,打懒标记即可。时间复杂度为 O(n)。另一点要注意的是边界 fu,0 的问题。

参考代码

#include <bits/stdc++.h>
using namespace std;
static constexpr int Maxn = 3e5 + 5;
int n, m;
int64_t ans1[Maxn], ans2[Maxn];
vector<int> g[Maxn];
vector<pair<int, int>> q[Maxn];
int sz[Maxn], dep[Maxn], mxd[Maxn], dson[Maxn];
void sack_init(int u, int fa, int depth) {
  dep[u] = depth, sz[u] = 1, dson[u] = 0;
  for (const int &v: g[u]) if (v != fa) {
    sack_init(v, u, depth + 1), sz[u] += sz[v];
    if (dson[u] == 0 || mxd[v] > mxd[dson[u]]) dson[u] = v;
  } mxd[u] = (dson[u] == 0 ? 0 : mxd[dson[u]]) + 1;
} // sack_init
int64_t tf[Maxn * 2], *tfp = tf;
int64_t *f[Maxn], lz[Maxn];
void longsack(int u, int fa) {
  if (dson[u] != 0) {
    f[dson[u]] = f[u] + 1;
    longsack(dson[u], u);
    int64_t nlz = lz[dson[u]] + sz[dson[u]] - 1;
    lz[u] += nlz, f[u][0] -= nlz;
  }
  for (const int &v: g[u]) if (v != fa && v != dson[u]) {
    f[v] = tfp, tfp += mxd[v] + 1;
    longsack(v, u);
    for (int j = 0; j < mxd[v]; ++j)
      f[u][j + 1] += f[v][j];
    int64_t nlz = lz[v] + sz[v] - 1;
    lz[u] += nlz, f[u][0] -= nlz;
  }
  for (const auto &[k, id]: q[u])
    ans1[id] = f[u][min(k, mxd[u] - 1)] + lz[u];
} // longsack
int main(void) {
  scanf("%d%d", &n, &m);
  for (int i = 2, u, v; i <= n; ++i)
    scanf("%d%d", &u, &v), g[u].push_back(v), g[v].push_back(u);
  sack_init(1, 0, 1);
  for (int i = 1, u, k; i <= m; ++i) {
    scanf("%d%d", &u, &k), q[u].push_back({k, i});
    ans2[i] = (int64_t)min(dep[u] - 1, k) * (sz[u] - 1);
  }
  f[1] = tfp, tfp += mxd[1] + 1;
  longsack(1, 0);
  for (int i = 1; i <= m; ++i)
    printf("%lld\n", ans1[i] + ans2[i]);
  exit(EXIT_SUCCESS);
} // main

例题 5 数树上块

​ 设 fu,j 表示在子树 u 中选出一个包含 u 的连通块,并且这个连通块中的点到 u 的最远距离为 j 的方案数。则有背包转移

fu,max{i,j+1}fu,max{i,j+1}+fu,ifv,j(where i+j<k is held)

直接转移的话是树形背包的复杂度 O(n2)

​ 考虑使用长链剖分优化。继承 dson 是很容易的;关键的问题是如何优化转移。转移时枚举 j ,然后我们考虑 fv,j 对哪些 fu,i 有贡献。容易发现,当 ijfv,j 只对 fu,j+1 有贡献,并且此时贡献是 fu,j+1fu,j+1+fv,ji=1min{j,kj1}fu,i;而当 i>jfv,j 的贡献区间为 [j+1,k(j+1)],并且此时贡献是 i[j+1,k(j+1)],fu,ifu,i+fu,ifv,j。于是在序列 fu 上要执行的操作就是前缀求和区间乘。对于每一条长链使用线段树维护即可。时间复杂度为 O(nlogn)

参考代码

#include <bits/stdc++.h>
using namespace std;
static constexpr int mod = 998244353;
inline int add(int x, int y) { return (x += y) >= mod ? x - mod : x; }
inline int mul(int x, int y) { return (int64_t)x * y % mod; }
inline int &add_eq(int &x, int y) { return x = add(x, y); }
inline int &mul_eq(int &x, int y) { return x = mul(x, y); }
static constexpr int Maxn = 5e5 + 5;
int n, k, ans;
struct Edge { int to, nxt; } e[Maxn << 1]; int head[Maxn], en;
void add_edge(int u, int v) { e[++en] = (Edge){v, head[u]}, head[u] = en; }
int dep[Maxn], par[Maxn], top[Maxn], ed[Maxn], son[Maxn], mxd[Maxn];
void pre_sack_1(int u, int fa, int depth) {
  dep[u] = depth, par[u] = fa; ed[u] = u, son[u] = 0;
  for (int i = head[u], v; i; i = e[i].nxt) if ((v = e[i].to) != fa)
    pre_sack_1(v, u, depth + 1), dep[ed[v]] > dep[ed[u]] && (ed[u] = ed[v], son[u] = v);
} // pre_sack_1
void pre_sack_2(int u, int fa, int topv) {
  top[u] = topv, mxd[u] = dep[ed[u]] - dep[u] + 1;
  if (son[u] != 0) pre_sack_2(son[u], u, topv);
  for (int i = head[u], v; i; i = e[i].nxt) if ((v = e[i].to) != fa)
    if (v != son[u]) pre_sack_2(v, u, v);
} // pre_sack_2
static constexpr int MaxN = Maxn * 8;
int ls[MaxN], rs[MaxN], tr[MaxN], lz[MaxN], rt[Maxn], tn;
inline void apply(int p, int v) { mul_eq(tr[p], v), mul_eq(lz[p], v); }
void build(int &p, int l, int r) {
  p = ++tn; tr[p] = 0, lz[p] = 1;
  if (l == r) return ;
  build(ls[p], l, (l + r) / 2);
  build(rs[p], (l + r) / 2 + 1, r);
} // build
void modify(int p, int l, int r, int L, int R, int v) {
  if (L <= l && r <= R) return apply(p, v);
  if (lz[p] != 1) apply(ls[p], lz[p]), apply(rs[p], lz[p]), lz[p] = 1;
  if (L <= (l + r) / 2) modify(ls[p], l, (l + r) / 2, L, R, v);
  if ((l + r) / 2 < R) modify(rs[p], (l + r) / 2 + 1, r, L, R, v);
  tr[p] = add(tr[ls[p]], tr[rs[p]]);
} // modify
void update(int p, int l, int r, int x, int v) {
  if (l == r) return add_eq(tr[p], v), void();
  if (lz[p] != 1) apply(ls[p], lz[p]), apply(rs[p], lz[p]), lz[p] = 1;
  if (x <= (l + r) / 2) update(ls[p], l, (l + r) / 2, x, v);
  else update(rs[p], (l + r) / 2 + 1, r, x, v);
  tr[p] = add(tr[ls[p]], tr[rs[p]]);
} // update
int ask(int p, int l, int r, int L, int R) {
  if (L == l && r == R) return tr[p];
  if (lz[p] != 1) apply(ls[p], lz[p]), apply(rs[p], lz[p]), lz[p] = 1;
  if (R <= (l + r) / 2) return ask(ls[p], l, (l + r) / 2, L, R);
  if (L > (l + r) / 2) return ask(rs[p], (l + r) / 2 + 1, r, L, R);
  return add(ask(ls[p], l, (l + r) / 2, L, (l + r) / 2),
             ask(rs[p], (l + r) / 2 + 1, r, (l + r) / 2 + 1, R));
} // ask
void sack(int u, int fa) {
  // f'_{u,\max_{i,j+1}}+=f_{u,i}f_{v,j}
  if (top[u] == u) build(rt[u], dep[top[u]], dep[ed[u]]);
  if (son[u] != 0) rt[son[u]] = rt[u], sack(son[u], u);
  update(rt[u], dep[top[u]], dep[ed[u]], dep[u], 1);
  for (int ei = head[u], v; ei; ei = e[ei].nxt)
    if ((v = e[ei].to) != fa) if (v != son[u]) {
      sack(v, u);
      int dv = min(mxd[v], k);
      static int fv[Maxn];
      for (int j = 0; j < dv; ++j)
        fv[j] = ask(rt[v], dep[top[v]], dep[ed[v]], dep[v] + j, dep[v] + j);
      // case <1>: i<j+1 => i\le j
      static int f1[Maxn];
      for (int j = 0; j < dv; ++j) {
        int imax = min(j, k - (j + 1)); // imax\le j
        int t = ask(rt[u], dep[top[u]], dep[ed[u]], dep[u], dep[u] + imax);
        f1[j] = mul(t, fv[j]);
      }
      // case <2>: i\ge j+1 => i>j
      int L = dep[u], R = dep[ed[u]], s = 0;
      for (int j = 0; j < dv; ++j) {
        int l = dep[u] + j + 1, r = min(dep[ed[u]], dep[u] + k - (j + 1));
        if (l != L) modify(rt[u], dep[top[u]], dep[ed[u]], L, l - 1, add(s, 1));
        if (r != R) modify(rt[u], dep[top[u]], dep[ed[u]], r + 1, R, add(s, 1));
        add_eq(s, fv[j]); L = l, R = r;
        if (L >= R) break;
      }
      if (L <= R) modify(rt[u], dep[top[u]], dep[ed[u]], L, R, add(s, 1));
      //  contribution of case <1>
      for (int j = 0; j < dv; ++j)
        update(rt[u], dep[top[u]], dep[ed[u]], dep[v] + j, f1[j]);
    }
  if (dep[u] + k < dep[ed[u]])
    modify(rt[u], dep[top[u]], dep[ed[u]], dep[u] + k + 1, dep[ed[u]], 0);
  add_eq(ans, ask(rt[u], dep[top[u]], dep[ed[u]], dep[u], min(dep[ed[u]], dep[u] + k)));
} // sack
int main(void) {
  scanf("%d%d", &n, &k);
  for (int i = 2, u, v; i <= n; ++i)
    scanf("%d%d", &u, &v), add_edge(u, v), add_edge(v, u);
  pre_sack_1(1, 0, 1), pre_sack_2(1, 0, 1);
  tn = 0, sack(1, 0);
  printf("%d\n", ans);
  exit(EXIT_SUCCESS);
} // main

例题 6 NFLSOJ #12738. 光明

题目描述

1.1 题目描述

​ 为了抽象人们向光明前进的脚步,我们建立一个具有 n 个点,点的编号分别为 1,,n 的有向图模型:对于所有 i[2,n],第 i 个点向第 fi(1fi<i) 个点连了一条边,代表一个人迈出的脚步。每个人都有自己的目标(用图上的一个节点 u 表示),并会为之付出一定的努力(用一个非负整数 i 表示)。我们将这样的计划抽象成二元组 (u,i),表示这个人想从起点为止开始恰好走 i 步到达 u,并用 f(u,i) 评估这个计划的可行程度,即有多少个起点满足恰好走 i 步能够到达 u。注意如果从一个起点出发到达了点 1 但还没走够恰 i 步则这个起点是失败的。

​ 来计算一下吧:给定正整数 k,请计算前 k 种最可行的方案,也即前 k 大的 f(u,i)。我们不需要知道它们分别是什么,我们只需要知道它们的和。

1.10 数据范围

​ 对 100% 的数据,保证 1n3×106,1k1018

​ 如果要直接维护前 k 大的 f(u,i) 是困难的;注意到 f(u,i)n 因此可以维护 f(u,i) 的桶。

​ 我们考虑在长链剖分求 f(u,i) 的过程中顺便维护桶。一种思路是考虑 f(u,i) 会贡献到哪些位置。可以发现 f(u,i) 贡献到的位置是其祖先序列上的一个区间。考虑长链剖分的合并过程,假设 fu,ifv,i1 合并为 fu,i,那么容易发现 fu,i,fv,i1u 祖先内就没有贡献了,而取而代之的则是 fu,i。因此我们可以在长链剖分合并的时候动态维护这个桶。继承重儿子没什么好说的,直接继承就是了。总时间复杂度为 O(n)

参考代码

#include <bits/stdc++.h>
using namespace std;
static constexpr int Maxn = 3e6 + 5;
int n, par[Maxn], en, head[Maxn];
struct Edge { int to, nxt; } e[Maxn];
void add_edge(int u, int v) { e[++en] = (Edge){v, head[u]}, head[u] = en; }
int dep[Maxn], dson[Maxn], mxd[Maxn];
int64_t k, ans, buc[Maxn];
void sack_init(int u) {
  dson[u] = 0, mxd[u] = 1;
  for (int i = head[u], v; i; i = e[i].nxt) {
    sack_init(v = e[i].to);
    if (mxd[u] < mxd[v] + 1)
      mxd[u] = mxd[v] + 1, dson[u] = v;
  }
} // sack_init
int *f[Maxn], tf[Maxn], top;
void longsack(int u) {
  if (dson[u] != 0) {
    f[dson[u]] = f[u] + 1;
    longsack(dson[u]);
  }
  buc[f[u][0] = 1] += dep[u];
  for (int i = head[u], v; i; i = e[i].nxt)
    if ((v = e[i].to) != dson[u]) {
      f[v] = tf + top, top += mxd[v];
      longsack(v);
      for (int j = 0; j < mxd[v]; ++j) {
        buc[f[u][j + 1]] -= dep[u];
        buc[f[v][j + 0]] -= dep[u];
        f[u][j + 1] += f[v][j + 0];
        buc[f[u][j + 1]] += dep[u];
      }
    }
} // dfs
int main(void) {
  freopen("light.in", "r", stdin);
  freopen("light.out", "w", stdout);
  extern uint32_t readu32(void);
  scanf("%d%lld", &n, &k); dep[1] = 1;
  for (int i = 2; i <= n; ++i)
    add_edge(par[i] = readu32(), i), dep[i] = dep[par[i]] + 1;
  sack_init(1);
  f[1] = tf + top, top += mxd[1];
  longsack(1);
  for (int i = n; i >= 1; --i) {
    int64_t t = min(k, buc[i]);
    ans += t * i, k -= t;
  }
  printf("%lld\n", ans);
  exit(EXIT_SUCCESS);
} // main
namespace fastio {
  static constexpr int BUF_SIZE = 1 << 21;
  char buf[BUF_SIZE], *p1, *p2, ch;
} // fastio
using namespace fastio;
#define getc() p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, BUF_SIZE, stdin), p1 == p2) ? EOF : *p1++;
uint32_t readu32(void) {
  uint32_t ret = 0; char c = getc();
  while (c < '0' || c > '9') c = getc();
  while (c >= '0' && c <= '9') ((ret += ret << 2) <<= 1) += c - '0', c = getc();
  return ret;
} // readu32

例题 7 「十二省联考 2019」希望

咕咕咕

习题 1 Blood Cousins

习题 2 Freezing with Style

习题 3 Treeland Tour

习题 4 NFLSOJ #12449. 上升

题目描述

3.1 Statement

给定一个 n 个点的无根树,每个点有一个标号 wi

定义一条链 (u,v) 的权值为按照从 uv 的顺序把这条链的标号写下来之后,得到的序列的最长上升子序列长度。

你需要删掉一个点,使得剩下的链的权值的最大值最小。

3.6 Constraints

对于 100% 的数据,保证 1n5×105,1wi109

习题 5 「SNOI2019」网络

静态点分治

​ 静态点分治主要用于解决一类树上路径问题。

核心思想 每次选一个点作为当前连通块的根结点,在这一层只统计经过该点的路径。然后各个子树之间便不再有贡献,于是可以删除根结点后对各个子树分治。

​ 每次只需要选取连通块重心作为分治中心,则可以保证各个子任务的规模都不超过原问题的一半。若每一次统计经过根结点的路径的数量级为 O(n),那么根据 Master Theorem 知枚举路径的运算次数数量级就是 T(n)=kT(nk)+O(n)=O(nlogn)

​ 假设当前分治的根结点为 z,则要统计的路径 (x,y) 满足 lca(x,y)=z。主要有两种方法:

  1. 若所求路径的条件及贡献是单元的,则可使用类似树形背包的方法,枚举 x,y 时先枚举 y,同时设置一个容器存放之前所有的 f(x) 满足 lca(x,y)=z
  2. 若所求路径的条件及贡献满足可加减性,则可以使用类似容斥的方法,先不考虑不能在同一棵子树内的限制,跑一遍答案后再在每个子树内单独跑一次把贡献减去即可。

由此也可以知道,静态点分治在统计树上路径时是优于 dsu on tree 的。在其它方面,前者不能做的事情是统计每棵子树内的路径,后者不能做的事情是对于每个点统计经过该点的路径 (后文中的一道例题会说到静态点分治是如何对于每个点统计经过该点的路径的)。

例题 1 [IOI2011]Race

​ 静态点分治的经典题。

​ 以当前分治中心为根,经过根的路径的长度可以转化为两个端点到根结点的距离。我们要凑出恰好为 k 的路径,由于 k 较小,可以直接开桶维护,然后用类似树形背包的方法转移即可。分治内每层的时间复杂度为 O(n),总时间复杂度为 O(nlogn)

参考代码
#include <stdio.h>
#include <string.h>

static char _in_buf[100000], *_in_p1 = _in_buf, *_in_p2 = _in_buf;
#define gc (__builtin_expect(_in_p1 == _in_p2, 0) && (_in_p2 = (_in_p1 = _in_buf) + \
        fread(_in_buf, 1, 100000, stdin), _in_p1 == _in_p2) ? -1 : *_in_p1 ++)
extern __inline int
__attribute__((__gnu_inline__, __always_inline__))
read(int *x) {
  register char ch = gc; *x = 0;
  while (ch < 48) ch = gc;
  while (ch > 47) *x = (*x << 3) + (*x << 1) + (ch ^ 48), ch = gc;
}

#define Maxn 200005
int n, k, ans;
struct Edge {
  int to, nxt, w;
} e[Maxn << 1];
int head[Maxn], tot_edge;
#define rint register int
#define fgraph for (rint _i_edge_ = head[u]; ~_i_edge_; _i_edge_ = e[_i_edge_].nxt)
#define v (e[_i_edge_].to)
#define w (e[_i_edge_].w)
int root, tot, mnsz;
int sz[Maxn];
bool vis[Maxn];
void get_root(int u, int par) {
  sz[u] = 1;
  int mx = 0;
  fgraph {
    if (v == par || vis[v]) continue;
    get_root(v, u), sz[u] += sz[v];
    (sz[v] > mx) && (mx = sz[v]);
  }
  (tot - sz[u] > mx) && (mx = tot - sz[u]);
  (mx < mnsz) && (mnsz = mx, root = u);
}
int qsz, pre;
int qdist[Maxn];
int qlen[Maxn];
int step[1000005];
void get_dist(int u, int par, int dist, int len) {
  if (dist > k) return ;
  qdist[++qsz] = dist;
  qlen[qsz] = len;
  fgraph {
    if (__builtin_expect(v == par || vis[v], false)) continue;
    get_dist(v, u, dist + w, len + 1);
  }
}
void tree_divide(int u) {
  vis[u] = true;
  step[0] = qsz = 0;
  pre = 0;
  fgraph {
    if (__builtin_expect(vis[v], false)) continue;
    get_dist(v, u, w, 1);
    for (rint i = pre + 1; i <= qsz; ++i)
      (qlen[i] + step[k - qdist[i]] < ans) && (ans = qlen[i] + step[k - qdist[i]]);
    for (rint i = pre + 1; i <= qsz; ++i)
      (qlen[i] < step[qdist[i]]) && (step[qdist[i]] = qlen[i]);
    pre = qsz;
  }
  for (rint i = 1; i <= qsz; ++i)
    step[qdist[i]] = 0x3f3f3f3f;
  fgraph {
    if (__builtin_expect(vis[v], false)) continue;
    root = -1, tot = sz[v];
    mnsz = 0x3f3f3f3f;
    get_root(v, u);
    tree_divide(root);
  }
}
#undef fgraph
#undef v
#undef w

int main() {
  read(&n), read(&k);
  memset(head, -1, sizeof(head));
  for (rint i = 1; i < n; ++i) {
    int u, v, w;
    read(&u), read(&v), read(&w);
    ++u, ++v;
    e[tot_edge] = (Edge){v, head[u], w};
    head[u] = tot_edge++;
    e[tot_edge] = (Edge){u, head[v], w};
    head[v] = tot_edge++;
  }
  memset(vis, false, sizeof(bool) * (n + 1));
  memset(step, 0x3f, sizeof(step));
  root = -1, tot = n, mnsz = 0x3f3f3f3f;
  get_root(1, -1);
  ans = 0x3f3f3f3f;
  tree_divide(root);
  if (ans < 0x3f3f3f3f) printf("%d\n", ans);
  else puts("-1");
  return 0;
}

例题 2 Close Vertices

​ 选取了分治中心 z 之后,令 wu 为连通块中点 u到根结点 z 的距离,则路径 (x,y) 符合要求当且仅当:

  1. lca(x,y)=z
  2. depx+depyL,这里我们定义 depz=0
  3. wx+wyW

注意到条件是单元的,因此可以考虑类似树形背包的方式枚举子树,枚举 y 时将 (depx,wx) 丢到容器里,查询时就是一个二维数点的问题。于是分治子问题可以转化为一个动态二维数点。动态二维数点的时间复杂度是 O(nlog2n) 的,原问题的时间复杂度是 O(nlog3n) 的。

​ 考虑优化。但动态二维数点的最低时间复杂度就是 Θ(nlog2n),无法降到更低。注意到分治子问题本身是静态的,而也正是由于类似树形背包的方式枚举子树才使得它变为动态的,因此我们可以考虑换个方式统计答案。注意到上述每一个条件满足可加减性,因此可以使用类似于容斥的方法。那么现在要统计的就是子树内满足 depx+depyLwx+wyW 的点对 (x,y) 的个数。可以按 w 排序后用树状数组维护 dep 类似于双指针的方法扫一遍整个数组。这样分治子问题的时间复杂度为 O(nlogn),原问题的时间复杂度就是 O(nlog2n)

参考代码
#include <bits/stdc++.h>
using namespace std;
template<typename _Tp> _Tp &min_eq(_Tp &x, const _Tp &y) { return x = min(x, y); }
template<typename _Tp> _Tp &max_eq(_Tp &x, const _Tp &y) { return x = max(x, y); }
static constexpr int Maxn = 1e5 + 5;
int n, LL, WW;
vector<pair<int, int>> g[Maxn];
bool visited[Maxn];
int64_t ans;
int totN, mnN, root, sz[Maxn];
void get_rooting(int u, int fa) {
  int mx = 0; sz[u] = 1;
  for (const auto &[v, w]: g[u]) if (!visited[v] && v != fa)
    get_rooting(v, u), sz[u] += sz[v], max_eq(mx, sz[v]);
  if (mnN > max_eq(mx, totN - sz[u]))
    mnN = mx, root = u;
} // get_rooting
inline int get_root(int u, int N) {
  totN = N, mnN = N, root = 0;
  get_rooting(u, 0); return root;
} // get_root
int m, L, W;
pair<int, int> buc[Maxn];
int64_t dfs_calc1(int u, int fa, int depth, int dis) {
  int64_t res = 0;
  if (depth && depth <= L && dis <= W) res++;
  for (const auto &[v, w]: g[u]) if (!visited[v] && v != fa)
    res += dfs_calc1(v, u, depth + 1, dis + w);
  return res;
} // dfs_calc1
void dfs_get(int u, int fa, int depth, int dis) {
  if (depth) buc[++m] = {depth, dis};
  for (const auto &[v, w]: g[u]) if (!visited[v] && v != fa)
    dfs_get(v, u, depth + 1, dis + w);
} // dfs_get
int bit[Maxn];
inline void add(int x, int w) {
  for (; x <= n; x += x & -x) bit[x] += w;
} // add
inline int ask(int x) {
  int r = 0;
  for (; x; x -= x & -x) r += bit[x];
  return r;
} // ask
int64_t divide_calc(int u, int Ll, int Ww) {
  L = Ll, W = Ww; m = 0;
  int64_t res = dfs_calc1(u, 0, 0, 0);
  for (const auto &[v, w]: g[u]) if (!visited[v])
    dfs_get(v, u, 1, w);
  sort(buc + 1, buc + m + 1, [&](auto lhs, auto rhs){
    return lhs.second < rhs.second;
  });
  int l = 1, r;
  for (l = 1, r = m; r >= 1; --r) {
    for (; l < r && buc[r].second + buc[l].second <= W; ++l)
      add(buc[l].first, 1);
    for (; l > r; --l)
      add(buc[l - 1].first, -1);
    res += ask(max(0, L - buc[r].first));
  }
  for (int i = 1; i < l; ++i)
    add(buc[i].first, -1);
  return res;
} // divide_calc
void divide(int u) {
  int64_t res = 0;
  res += divide_calc(u, LL, WW); visited[u] = true;
  for (const auto &[v, w]: g[u]) if (!visited[v])
    res -= divide_calc(v, LL - 2, WW - 2 * w);
  ans += res;
  for (const auto &[v, w]: g[u]) if (!visited[v])
    divide(get_root(v, sz[v]));
} // divide
int main(void) {
  scanf("%d%d%d", &n, &LL, &WW);
  for (int i = 2; i <= n; ++i) {
    int fa, w;
    scanf("%d%d", &fa, &w);
    g[fa].push_back({i, w});
    g[i].push_back({fa, w});
  }
  divide(get_root(1, n));
  printf("%lld\n", ans);
  exit(EXIT_SUCCESS);
} // main

例题 3 Palindromes in a Tree

​ 这个题要做的事情是对于每个点统计经过该点的路径。

​ 考虑当前分治层选的分治中心为 z,处理所有在当前连通块内且经过点 z 的路径。不难发现,路径 xy 对该路径上所有点都产生贡献。换句话说,对于一个点 x,所有一端在 x 子树内的路径对点 x 都产生了一个贡献。于是我们可以考虑树上差分,即对于每个点 x,统计所有一端为 x 的路径,最后再做一次子树和。

​ 至于如何统计一端为 x 的路径,这个是平凡的。压位后使用类似树形背包的枚举方式,同时开一个桶维护每一种 mask 个数即可。总时间复杂度为 O(||nlogn)

参考代码
#include <bits/stdc++.h>
using namespace std;
template<typename _Tp> _Tp &min_eq(_Tp &x, const _Tp &y) { return x = min(x, y); }
template<typename _Tp> _Tp &max_eq(_Tp &x, const _Tp &y) { return x = max(x, y); }
static constexpr int Maxn = 2e5 + 5;
static constexpr int Maxk = 20;
int n;
int64_t ans[Maxn];
char str[Maxn];
vector<int> g[Maxn];
bool visited[Maxn];
int totN, root, mnN, sz[Maxn];
void get_rooting(int u, int fa) {
  int mxS = 0; sz[u] = 1;
  for (const int &v: g[u]) if (v != fa && !visited[v])
    get_rooting(v, u), sz[u] += sz[v], max_eq(mxS, sz[v]);
  max_eq(mxS, totN - sz[u]);
  if (mnN > mxS) mnN = mxS, root = u;
} // get_rooting
int get_root(int u, int N) {
  totN = N, mnN = N;
  get_rooting(u, 0);
  return root;
} // get_root
void divide_calc(int);
void divide(int u) {
  divide_calc(u); visited[u] = true;
  for (const int &v: g[u]) if (!visited[v])
    divide(get_root(v, sz[v]));
} // divide
int buc[1 << Maxk];
void dfs_add(int u, int fa, int s, int c) {
  s ^= (1 << (str[u] - 'a')); buc[s] += c;
  for (const int &v: g[u]) if (v != fa && !visited[v])
    dfs_add(v, u, s, c);
} // dfs_add
int64_t dfs_calc(int u, int fa, int s) {
  s ^= (1 << str[u] - 'a');
  int64_t res = 0;
  res += buc[s];
  for (int i = 0; i < Maxk; ++i)
    res += buc[(1 << i) ^ s];
  for (const int &v: g[u]) if (v != fa && !visited[v])
    res += dfs_calc(v, u, s);
  ans[u] += res; return res;
} // dfs_calc
int64_t dfs_calc1(int u, int fa, int s) {
  s ^= (1 << str[u] - 'a');
  int64_t res = 0;
  res |= (s == 0);
  for (int i = 0; i < Maxk; ++i)
    res |= (s == (1 << i));
  for (const int &v: g[u]) if (v != fa && !visited[v])
    res += dfs_calc1(v, u, s);
  ans[u] += res; return res;
} // dfs_calc1
void divide_calc(int u) {
  vector<int> son;
  for (const int &v: g[u]) if (!visited[v])
    son.push_back(v);
  int64_t res = 0;
  for (int dir = 0; dir < 2; ++dir) {
    for (const int &v: son)
      res += dfs_calc(v, u, 0), dfs_add(v, u, 1 << (str[u] - 'a'), 1);
    for (const int &v: son)
      dfs_add(v, u, 1 << (str[u] - 'a'), -1);
    reverse(son.begin(), son.end());
  }
  ans[u] += res / 2; ++ans[u];
  for (const int &v: son)
    ans[u] += dfs_calc1(v, u, 1 << (str[u] - 'a'));
} // divide_calc
int main(void) {
  scanf("%d", &n);
  for (int i = 2, u, v; i <= n; ++i)
    scanf("%d%d", &u, &v), g[u].push_back(v), g[v].push_back(u);
  scanf("%s", str + 1);
  divide(get_root(1, n));
  for (int i = 1; i <= n; ++i)
    printf("%lld%c", ans[i], " \n"[i == n]);
  exit(EXIT_SUCCESS);
} // main

例题 4 Freezing with Style

​ 对于最大化所选集合中位数这一类题,有个经典的套路:二分答案 z 后将集合中所有 z 的数的权值赋为 1,所有 <z 的数的权值赋为 1 。若一个集合内所有数的权值的和 0 那么这个集合内所有数的中位数就 z

​ 于是按照上述,原题即可转化为:树上是否有一条路径的边数在 [L,R] 中且边权值和 0

​ 由于是路径问题,因此可以考虑点分治。假设当前分治中心为 z。可以按照类似树形背包的方式枚举子树,同时维护一个容器内装着之前所有点到点 z 的距离。那么对于当前子树内枚举的点 y,设其深度为 depy (这里有 depz=0 ),那么该路径的另一端点 x 要满足 LdepydepxRdepy,所查询的便是所有满足这个条件的 x 到点 z 的距离的最大值。注意观察这个条件式很像一个区间查询,于是我们考虑用线段树维护之前所有子树内每一个深度所对应的所有 xz 距离的最大值,那么枚举当前子树内的点 y 时便可以直接在线段树上查询区间 [Ldepy,Rdepy] 最大值。于是我们就得到了一个 O(nlog2n) 的点分治做法。加上原问题的二分,总时间复杂度是 O(nlog3n)

​ 考虑优化。注意上述做法在线段树上查询的区间 [Ldepy,Rdepy] 长度不变,因此我们感觉到线段树是多余的;而在枚举当前子树内结点 y 查询时是静止的过程,因此我们可以使用单调队列代替线段树。我们换一个顺序枚举当前子树,按照 bfs 序枚举,那么枚举到的点 y 就有 depy 单调不降,因此查询的区间就是一直向左移的。于是似乎直接使用单调队列维护区间最大值就行了?

​ 然而并不行。注意到单调队列的初始化是 O(N) 的,其中 N 是数组的长度。在这道题里面,单调队列的长度是之前所有子树的最深深度,因此直接使用单调队列维护区间最大值的做法会被一个蒲公英hack掉。那怎么办呢?这里就要运用一个叫做单调队列按秩合并trick。注意到之前所有子树的最深深度,于是可以联想到树上按秩合并。我们强制每一次加入的子树的最深深度是当前所有子树最深深度的最大值,那么这一分治层的时间复杂度就是 O(vCzmxdv)O(|Sz|)。于是点分治的时间复杂度就是 O(nlogn) 的,总时间复杂度是 O(nlog2n)

参考代码

#include <bits/stdc++.h>
using namespace std;
template<typename _Tp> _Tp &min_eq(_Tp &x, const _Tp &y) { return x = min(x, y); }
template<typename _Tp> _Tp &max_eq(_Tp &x, const _Tp &y) { return x = max(x, y); }
static constexpr int Maxn = 1e5 + 5;
static constexpr int inf = 0x3f3f3f3f;
int n, ll, rr;
struct Edge {
  int to, w; int64_t eW;
  Edge() = default;
  Edge(int to, int64_t W) : to(to), eW(W) { }
  void update(int64_t X) { w = (eW >= X ? 1 : -1); }
};
vector<Edge> g[Maxn];
namespace ndt {
  static constexpr int Maxn = ::Maxn;
  bool visited[Maxn];
  int totN, mnN, root, sz[Maxn];
  int calcsz(int u, int fa) {
    sz[u] = 1;
    for (const Edge &E: g[u])
      if (E.to != fa && !visited[E.to])
        sz[u] += calcsz(E.to, u);
    return sz[u];
  } // ndt::calcsz
  void get_rooting(int u, int fa) {
    int mx = 0; sz[u] = 1;
    for (const Edge &E: g[u])
      if (E.to != fa && !visited[E.to])
        get_rooting(E.to, u), sz[u] += sz[E.to], max_eq(mx, sz[E.to]);
    if (max_eq(mx, totN - sz[u]) < mnN)
      mnN = mx, root = u;
  } // ndt::get_rooting
  inline int get_root(int u) {
    totN = calcsz(u, 0), mnN = totN, root = -1;
    get_rooting(u, 0); return root;
  } // ndt::get_root
  void divide(int *par, int u, int fa = 0) {
    par[u] = fa; visited[u] = true;
    for (const Edge &E: g[u])
      if (!visited[E.to])
        divide(par, get_root(E.to), u);
  } // ndt::divide
  inline void build(int n, int *par) {
    fill(visited + 1, visited + n + 1, false);
    divide(par, get_root(1));
  } // ndt::build
} // namespace ndt
int vpar[Maxn], root;
vector<int> g_root[Maxn];
bool visited[Maxn];
pair<int, int> Ans;
int dep[Maxn], mxdep[Maxn], dis[Maxn];
int vf[Maxn], uf[Maxn];
int dfs_dep(int u, int fa, int depth, int W) {
  int mxd = dep[u] = depth; dis[u] = W;
  if (!vf[depth] || dis[vf[depth]] < dis[u]) vf[depth] = u;
  for (const Edge &E: g[u])
    if (!visited[E.to] && E.to != fa)
      max_eq(mxd, dfs_dep(E.to, u, depth + 1, W + E.w));
  return mxdep[u] = mxd;
} // dfs_dep
bool calc(int u) {
  vector<Edge> son;
  mxdep[u] = 0;
  for (const Edge &E: g[u])
    if (!visited[E.to]) {
      son.push_back(E);
      mxdep[E.to] = dfs_dep(E.to, u, 1, E.w);
      max_eq(mxdep[u], mxdep[E.to]);
    }
  sort(son.begin(), son.end(), 
      [&](const Edge &lhs, const Edge &rhs) {
    return mxdep[lhs.to] < mxdep[rhs.to];
  });
  dis[0] = -0x3f3f3f3f;
  fill(uf, uf + mxdep[u] + 1, 0);
  fill(vf, vf + mxdep[u] + 1, 0);
  uf[0] = u; int Mx = 0;
  for (const auto &E: son) {
    dfs_dep(E.to, u, 1, E.w);
    static int que[Maxn];
    int qh = 1, qe = 0;
    for (int j = min(rr, mxdep[E.to]), i = 0; j >= 0; --j) {
      for (; i <= Mx && i + j <= rr; que[++qe] = i++)
        while (qh <= qe && dis[uf[que[qe]]] < dis[uf[i]]) --qe;
      while (qh <= qe && que[qh] + j < ll) ++qh;
      if (qh <= qe && que[qh] + j <= rr && dis[uf[que[qh]]] + dis[vf[j]] >= 0)
        return Ans = {uf[que[qh]], vf[j]}, true;
    }
    Mx = mxdep[E.to];
    for (int j = 0; j <= Mx; ++j) {
      if (uf[j] == 0 || dis[vf[j]] > dis[uf[j]])
        uf[j] = vf[j];
    }
    fill(vf, vf + Mx + 1, 0);
  }
  return false;
} // calc
bool divide(int u) {
  visited[u] = true;
  if (calc(u)) return true;
  for (const int &v: g_root[u])
    if (divide(v)) return true;
  return false;
} // divide
bool calc_answer(int X) {
  for (int i = 1; i <= n; ++i)
    for (Edge &E: g[i]) E.update(X);
  fill(visited + 1, visited + n + 1, false);
  Ans = {-1, -1};
  return divide(root);
} // calc
int main(void) {
  scanf("%d%d%d", &n, &ll, &rr);
  int low = inf, high = -inf;
  for (int i = 1; i < n; ++i) {
    int u, v, w;
    scanf("%d%d%d", &u, &v, &w);
    g[u].push_back(Edge(v, w));
    g[v].push_back(Edge(u, w));
    min_eq(low, w);
    max_eq(high, w);
  }
  ndt::build(n, vpar);
  for (int i = 1; i <= n; ++i) {
    if (!vpar[i]) root = i;
    else g_root[vpar[i]].push_back(i);
  }
  int exact = low;
  while (low <= high) {
    int mid = (low + high) >> 1;
    bool validity = calc_answer(mid);
    if (validity) low = mid + 1, exact = mid;
    else high = mid - 1;
  }
  assert(calc_answer(exact));
  printf("%d %d\n", Ans.first, Ans.second);
  exit(EXIT_SUCCESS);
} // main

例题 5 NFLSOJ #12365.「NOIP2021模拟赛0820南外」发怒

题目描述

4.1 Description

​ 狂暴贴贴!

/fn 有一棵 n 个点的树,树上每个点有一个正整数权值 ai。定义点集的一个子集 S 是连通的,当且仅当在树上仅保留 S 内的点以及它们之间的边后,这张图是连通的。定义 S 的权值为其包含的所有结点的权值之积。

/fn 想要知道,这棵树上有多少非空的连通的且权值 m 的子集 S,答案对 109+7 取模。

4.5 Constraint

​ 对于全部数据,1n2000,1m106,aim

​ 首先有一个非常显然的树上背包的做法:设 fu,j 表示有多少个以 u 为根的树上连通块,满足连通块中所有 ai 的乘积恰好为 j。使用前缀和优化后便可以利用树上背包做到 O(nm)。但这显然过不去。

​ 考虑优化。注意到上述DP状态的第二维记录乘积恰好为 j 很浪费;我们只需要关心 j 最多还能乘多少不超过 m,而这个数为 mj。因此我们可以考虑换一个DP状态设计:设 fu,j 表示有多少个以 u 为根的树上连通块,满足 m 除以连通块中所有 ai 的乘积为 j。这样根据整除分块那套理论,此时DP状态的第二维就是 O(m) 的。于是树上背包的复杂度就降到了 O(m) 了。

​ 但这样真的就可以了么?并不是的。注意到我们现在记录的是 mj,这就出现了一个问题:即使我们已知 mj1mj2 也不能推出 mj1j2,也就是说现在的背包状态是不能合并的了。但我们仍然可以支持向目前集合中只加入一个数 x,因为我们可以通过已知 mjx 来推出 mjx。这使得我们直接树上背包变得非常棘手。

​ 怎么办呢?这时就要请出今天的主角:静态点分治。

​ 静态点分治其实还有另一类不常见的应用:统计树上的连通块。具体来说,每次在当前处理的分治连通块中统计包含分治中心的连通块。由于我们要统计的是树上连通块,因此一个点如果被选入树上连通块中,那么它在分治连通块内的祖先们也肯定要被选择。

​ 我们可以考虑一个树上依赖性背包。设 fi,j 表示在当前连通块中考虑了所有 DFSi 的结点,它们权值的乘积 x 满足 mx=j 的方案数 (这个 DFS 序可以视作已经考虑过的点的DP状态)。我们现在考虑如何转移。我们 DFS 到一个结点 x 并遍历它的一个儿子 y 时,我们将在当前背包中加入 ay 后的背包传给它的儿子 y,然后结点 y 带着这个背包到 y 的子树内逛一圈;上推的时候,再将 y 的背包值传给 x。这样的背包过程正确性是显然的,而且在转移过程中只有添加单个结点的操作,不会出现上述的问题。分治层内背包的时间复杂度为 O(nm),总时间复杂度为 O(nmlogn)。可以通过。

参考代码
#include <bits/stdc++.h>
using namespace std;
template<typename _Tp> _Tp &min_eq(_Tp &x, const _Tp &y) { return x = min(x, y); }
template<typename _Tp> _Tp &max_eq(_Tp &x, const _Tp &y) { return x = max(x, y); }
static constexpr int mod = 1e9 + 7;
inline int add(int x, int y) { return (x += y) >= mod ? x - mod : x; }
inline int sub(int x, int y) { return add(x, mod - y); }
inline int mul(int x, int y) { return (int64_t)x * y % mod; }
inline int &add_eq(int &x, int y) { return x = add(x, y); }
inline int &sub_eq(int &x, int y) { return x = sub(x, y); }
inline int &mul_eq(int &x, int y) { return x = mul(x, y); }
static constexpr int Maxn = 2005, Maxm = 1e6 + 5, MaxmS = 1e3 + 5;
int n, m, ans, a[Maxn], mn;
vector<int> g[Maxn];
int nd[Maxm], ind[Maxm];
bool visited[Maxn];
int totN, mnN, root, sz[Maxn];
int calcsz(int u, int fa) {
  sz[u] = 1;
  for (const int &v: g[u])
    if (v != fa && !visited[v])
      sz[u] += calcsz(v, u);
  return sz[u];
} // calcsz
void get_rooting(int u, int fa) {
  int mx = 0; sz[u] = 1;
  for (const int &v: g[u])
    if (v != fa && !visited[v])
      get_rooting(v, u), sz[u] += sz[v], max_eq(mx, sz[v]);
  if (max_eq(mx, totN - sz[u]) < mnN)
    mnN = mx, root = u;
} // get_rooting
inline int get_root(int u) {
  totN = calcsz(u, 0), mnN = totN, root = -1;
  get_rooting(u, 0); return root;
} // get_root
int dp[Maxn][MaxmS * 2];
void dfs(int u, int fa) {
  for (const int &v: g[u])
    if (v != fa && !visited[v]) {
      for (int i = 1; i <= mn; ++i)
        add_eq(dp[v][ind[nd[i] / a[v]]], dp[u][i]);
      dfs(v, u);
      for (int i = 1; i <= mn; ++i)
        add_eq(dp[u][i], dp[v][i]);
    }
} // dfs
void dfs_clear(int u, int fa) {
  memset(dp[u], 0, sizeof(dp[u]));
  for (const int &v: g[u])
    if (v != fa && !visited[v])
      dfs_clear(v, u);
} // dfs_clear
int calc(int u) {
  int ans = 0;
  dp[u][ind[m / a[u]]] = 1; dfs(u, 0);
  for (int i = 1; i <= mn; ++i)
    add_eq(ans, dp[u][i]);
  dfs_clear(u, 0);
  return ans;
} // calc
void divide(int u) {
  add_eq(ans, calc(u));
  visited[u] = true;
  for (const int &v: g[u])
    if (!visited[v])
      divide(get_root(v));
} // divide
int main(void) {
  freopen("fn.in", "r", stdin);
  freopen("fn.out", "w", stdout);
  scanf("%d%d", &n, &m);
  for (int l = 1; l <= m; l = (m / (m / l)) + 1)
    nd[ind[m / l] = ++mn] = m / l;
  for (int i = 1; i <= n; ++i) scanf("%d", &a[i]);
  for (int i = 2, u, v; i <= n; ++i)
    scanf("%d%d", &u, &v), g[u].push_back(v), g[v].push_back(u);
  memset(visited, false, sizeof(visited));
  divide(get_root(1));
  printf("%lld\n", ans);
  exit(EXIT_SUCCESS);
} // main

总结 若某个树上连通块类型的背包转移时不支持子树合并,但支持插入单点,那么可以考虑使用静态点分治将子树合并转为插入单点后使用树上依赖性背包。

例题 6 「NOI2014」购票

​ 设 fu 为结点 u 的答案。不难想到一个DP转移方程:fu=minv{fv+pudis(u,v)+qu},其中结点 v 满足条件:vu 的祖先,且 dis(u,v)lu。直接暴力转移是 O(n2) 的,考虑优化。

​ 设 du=dis(1,u),那么有 fu=minv{fvdvpu+pudu+qu}。注意到 pudu+qu 是关于 u 的常量,我们真正要关心的就是 dvpu+fv,而这可以看作为一个直线 y=kx+b,其中参数 k=dv,b=fv。于是DP转移方程可以看作为求若干直线在某个点的最小值。但即使这样原问题仍然并不好做。于是我们可以考虑先思考弱化版问题。

  1. 原树是一条链,且DP转移时 v 无距离限制。

    此时的DP状态转移方程为 fu=min1v<u{dvpu+fv}+pudu+qu。这显然是一个斜率优化的式子,可以使用单调栈维护。

  2. 原树是一条链,但DP转移时 v 有距离限制。

    此时的DP状态转移方程不变,但转移范围变了。由于转移范围并不具有单调性,因此并不能直接用单调栈去做。

    一种做法是使用cdq分治,计算跨越隔板的贡献。我们先递归左边,将左半部分的DP值算出来;对于右半部分的点按照转移范围从小到大排序,每次将所有在转移范围内的位置所对应的直线加入到单调栈中,然后查询时在凸包上二分即可。时间复杂度为 O(nlog2n),其中cdq分治有一个 log,在凸包上二分有一个 log

  3. 无特殊条件

    可以考虑将链上的做法拓展到树上。于是我们需要一个树上cdq分治,也就是静态点分治。设当前分治层的分治中心为 z。类似于序列上的做法,我们先递归求出根节点到 z 这一部分的DP值,然后再处理 z 子树内的所有点的贡献。总时间复杂度为 O(nlog2n)

参考代码

#include <bits/stdc++.h>
using namespace std;
template<typename _Tp> _Tp &min_eq(_Tp &x, const _Tp &y) { return x = min(x, y); }
template<typename _Tp> _Tp &max_eq(_Tp &x, const _Tp &y) { return x = max(x, y); }
using real_t = long double;
static constexpr real_t eps = 1e-9;
static constexpr int64_t inf = 0x3f3f3f3f3f3f3f3f;
static constexpr int Maxn = 2e5 + 5;
int n, par[Maxn], head[Maxn], en;
struct Edge { int to, nxt; } e[Maxn * 2];
int64_t dw[Maxn], pw[Maxn], qw[Maxn], lw[Maxn];
inline void add_edge(int u, int v) { e[++en] = (Edge){v, head[u]}, head[u] = en; }
bool visited[Maxn];
int totN, mnN, root, sz[Maxn];
void calcsz(int u, int fa) {
  for (int i = head[u], v; i; i = e[i].nxt)
    if ((v = e[i].to) != fa && !visited[v])
      calcsz(v, u), ++totN;
} // calcsz
void get_rooting(int u, int fa) {
  int mx = 0; sz[u] = 1;
  for (int i = head[u], v; i; i = e[i].nxt)
    if ((v = e[i].to) != fa && !visited[v])
      get_rooting(v, u), sz[u] += sz[v], max_eq(mx, sz[v]);
  if (max_eq(mx, totN - sz[u]) < mnN)
    mnN = mx, root = u;
} // get_rooting
inline int get_root(int u) {
  totN = 1, calcsz(u, 0);
  mnN = totN, root = -1;
  get_rooting(u, 0); return root;
} // get_root
void divide(int u) {
  int fu = u;
  while (fu != 0 && !visited[fu]) fu = par[fu];
  visited[u] = true;
  if (fu != par[u]) divide(get_root(par[u]));
  extern void divide_calc(int, const int&);
  divide_calc(u, fu);
  for (int i = head[u], v; i; i = e[i].nxt)
    if (!visited[v = e[i].to]) divide(get_root(v));
} // divide
int64_t f[Maxn];
int m, U;
struct datum {
  int u;
  int64_t lim;
  datum() { }
  datum(int u, int64_t lim) : u(u), lim(lim) { }
  friend bool operator < (const datum &lhs, const datum &rhs) {
    return lhs.lim < rhs.lim;
  }
} a[Maxn];
void dfs1(int u, int fa) {
  if (lw[u] >= dw[u] - dw[U])
    a[++m] = datum(u, lw[u] - (dw[u] - dw[U]));
  for (int i = head[u], v; i; i = e[i].nxt)
    if ((v = e[i].to) != fa && !visited[v]) dfs1(v, u);
} // dfs1
struct line {
  int64_t k, b;
  line() : k(0LL), b(0LL) { }
  line(int64_t k, int64_t b) : k(k), b(b) { }
  int64_t eval(int64_t x) const { return k * x + b; }
};
inline real_t inter(const line &lhs, const line &rhs) {
  assert(lhs.k != rhs.k);
  return ((real_t)(rhs.b - lhs.b)) / (lhs.k - rhs.k);
} // calc
void divide_calc(int u, const int &Fu) {
  U = u; m = 0;
  for (int i = head[u], v; i; i = e[i].nxt)
    if (!visited[v = e[i].to]) dfs1(v, u);
  sort(a + 1, a + m + 1);
  for (int fu = par[u]; fu != Fu && dw[u] - dw[fu] <= lw[u]; fu = par[fu])
    min_eq(f[u], f[fu] - pw[u] * dw[fu] + pw[u] * dw[u] + qw[u]);
  static line stk[Maxn]; int top = 0;
  auto insert = [&](int64_t k, int64_t b) {
    line t = line(k, b);
    for (; top > 1 && inter(stk[top - 1], stk[top]) < inter(stk[top - 1], t); --top);
    stk[++top] = t;
  };
  for (int i = 1, fu = u; i <= m; ++i) {
    while (fu != Fu && dw[u] - dw[fu] <= a[i].lim)
      insert(-dw[fu], f[fu]), fu = par[fu];
    if (top != 0) {
      int v = a[i].u;
      int low = 1, high = top;
      while (low < high) {
        int mid = (low + high + 1) / 2;
        if (inter(stk[mid - 1], stk[mid]) > pw[v]) low = mid;
        else high = mid - 1;
      }
      min_eq(f[v], stk[low].eval(pw[v]) + pw[v] * dw[v] + qw[v]);
    }
  }
} // divide_calc
int main(void) {
  scanf("%d%*d", &n);
  for (int i = 2; i <= n; ++i) {
    scanf("%d", &par[i]);
    add_edge(par[i], i);
    add_edge(i, par[i]);
    int64_t w; scanf("%lld", &w);
    dw[i] = dw[par[i]] + w;
    scanf("%lld%lld%lld", &pw[i], &qw[i], &lw[i]);
  }
  memset(f, inf, sizeof(f)); f[1] = 0;
  memset(visited, false, sizeof(visited));
  divide(get_root(1));
  for (int i = 2; i <= n; ++i)
    printf("%lld\n", f[i]);
  exit(EXIT_SUCCESS);
} // main

总结 静态点分治其实就是树上cdq分治。

习题 1 Tree

习题 2 Ruri Loves Maschera

习题 3 「BalticOI 2021 Day1」Inside information

习题 4 [USACO12NOV]Balanced Trees G

习题 5 「LibreOJ Round #11」Misaka Network 与 Accelerator

习题 6 NFLSOJ #12755. sumsum

题目描述

给一棵 n 个节点的树,用 1n 的整数表示。每个节点上有一个整数权值 ai。再给出两个整数 L,R。现在有 m 个操作,每个操作这样描述:

  • 给定树上两个节点 u,v 和一个整数 d,表示将树上 uv 唯一的简单路径上每个点的权值 ai 都加上 d。之后求树上所有节点个数大于等于 L 小于等于 R 的简单路径的节点权值和之和。

注意这里有两次求和:对于一条节点个数大于等于 L 小于等于 R 的简单路径,求出它所有节点的权值之和;然后对所有这样的路径,对它们的权值和再进行求和。因为答案很大,只用输出对 Q=109+7 取余的结果即可。

对于 100% 的数据,有 1n,m105

习题 7 [COCI2019] Transport

习题 8 [BJOI2017]树的难题

习题 9 Shopping

习题 10 「JOISC 2020 Day4」首都城市

点分树

​ 点分树是点分治的结构树。具体来说,每次分治的时候,让各个子树的分治中心连接到当前的分治中心上,形成了一个树状结构,这就是点分树。点分树具有以下几条非常实用的性质:

  1. 点分树不一定是二叉树,但其高度是 O(logn)
  2. 对于任意两点 x,y,设 z^ 为点分树上 x,y 的 LCA,则 z^ 一定在原树中 x,y 两点的简单路径上。并且若设 Z^ 为点分树上点 x,y 的公共祖先的集合,则 z^ 是集合 Z^ 中唯一满足这个性质的

例题 1 Forest Game

​ 考虑建立点分树 T,那么每个点的贡献是自己被分治到的次数,也就是点分树 T 上该点的深度。根据期望的线性性,我们分别算出每个点的期望深度,加在一起就是答案。

​ 这里运用一个小 trick:将一个点的深度拆成 depu=vV[lca(u,v)=v]。而在点分树上 vu 的祖先当且仅当原树的路径 (u,v) 上点 v 是第一个被选为分治中心的。设 dis(u,v) 为路径 uv 上的点数,那么上述事件的概率 E[lca(u,v)=v]=1dis(u,v)。于是由期望的线性性,就可以得出 E[depu]=vVE[lca(u,v)=v]=vV1dis(u,v)。所以该题的答案即为 i=1nj=1n1dis(i,j)

​ 考虑如何维护上式。注意这是分式求和,难以批量维护;而分母最多只有 n 种,于是可以考虑记录分母状态。设 ck 为原树上有 ck 条路径长度为 k,那么答案就是 k=1nckk。于是现在只需求出 ck 即可。

​ 可以考虑使用点分治。设当前层的分治中心为 z,记录目前连通块内每个点到 z 的距离,那么使用类似树形背包的方法统计路径,合并两个子树的过程就是一个卷积。于是就做完了。

​ 总时间复杂度为 O(nlog2n)

参考代码
#include <bits/stdc++.h>
using namespace std;
template<typename _Tp> _Tp &min_eq(_Tp &x, const _Tp &y) { return x = min(x, y); }
template<typename _Tp> _Tp &max_eq(_Tp &x, const _Tp &y) { return x = max(x, y); }
static constexpr int mod = 1e9 + 7;
inline int add(int x, int y) { return (x += y) >= mod ? x - mod : x; }
inline int mul(int x, int y) { return (int64_t)x * y % mod; }
inline int &add_eq(int &x, int y) { return x = add(x, y); }
inline int &mul_eq(int &x, int y) { return x = mul(x, y); }
static constexpr int Maxn = 1e5 + 5;
int n, ans[Maxn];
vector<int> g[Maxn];
bool visited[Maxn];
int totN, mnN, root, sz[Maxn];
void get_rooting(int u, int fa) {
  int mx = 0; sz[u] = 1;
  for (const int &v: g[u]) if (!visited[v] && v != fa)
    get_rooting(v, u), sz[u] += sz[v], max_eq(mx, sz[v]);
  if (mnN > max_eq(mx, totN - sz[u]))
    mnN = mx, root = u;
} // get_rooting
inline int get_root(int u, int N) {
  totN = N, mnN = N; root = 0;
  get_rooting(u, 0); return root;
} // get_root
void divide_calc(int);
void divide(int u) {
  divide_calc(u);
  visited[u] = true;
  for (const int &v: g[u]) if (!visited[v])
    divide(get_root(v, sz[v]));
} // divide
namespace conv {
  namespace __fft {
    typedef double real_t;
    struct complex_t {
      real_t x, y;
      complex_t() { x = y = 0; }
      complex_t(real_t x_, real_t y_) : x(x_), y(y_) { }
    };
    inline complex_t operator + (complex_t a, complex_t b) { return complex_t(a.x + b.x, a.y + b.y); }
    inline complex_t operator - (complex_t a, complex_t b) { return complex_t(a.x - b.x, a.y - b.y); }
    inline complex_t operator * (complex_t a, complex_t b) { return complex_t(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); }
    inline complex_t conj(complex_t a) { return complex_t(a.x, -a.y); }
    int __base = 1;
    std::vector<complex_t> __roots = {{0, 0}, {1, 0}};
    std::vector<int> __rev = {0, 1};
    const real_t PI = static_cast<real_t>(acosl(-1.0));
    void ensure_base(int nbase) {
      if (nbase <= __base) return;
      __rev.resize(1 << nbase);
      for (int i = 0; i < (1 << nbase); ++i)
        __rev[i] = (__rev[i >> 1] >> 1) + ((i & 1) << (nbase - 1));
      __roots.resize(1 << nbase);
      while (__base < nbase) {
        real_t angle = 2 * PI / (1 << (__base + 1));
        for (int i = 1 << (__base - 1); i < (1 << __base); ++i) {
          __roots[i << 1] = __roots[i];
          real_t angle_i = angle * (2 * i + 1 - (1 << __base));
          __roots[(i << 1) + 1] = complex_t(cosl(angle_i), sinl(angle_i));
        }
        __base += 1;
      }
    }
    void fft(std::vector<complex_t> &a, int n = -1) {
      if (n == -1) n = static_cast<int>(a.size());
      int zeros = __builtin_ctz(n);
      ensure_base(zeros);
      int shift = __base - zeros;
      for (int i = 0; i < n; ++i)
        if (i < (__rev[i] >> shift))
          std::swap(a[i], a[__rev[i] >> shift]);
      for (int i = 0; i < n; i += 2) {
        complex_t z = a[i + 1] * __roots[1];
        a[i + 1] = a[i] - z;
        a[i] = a[i] + z;
      }
      for (int i = 0; i < n; i += 2 * 2) {
        complex_t z;
        z = a[i + 2] * __roots[2];
        a[i + 2] = a[i] - z;
        a[i] = a[i] + z;
        z = a[i + 1 + 2] * __roots[1 + 2];
        a[i + 1 + 2] = a[i + 1] - z;
        a[i + 1] = a[i + 1] + z;
      }
      for (int k = 4; k < n; k <<= 1) {
        for (int i = 0; i < n; i += 2 * k) {
          complex_t z;
          for (int j = 0; j < k; ) {
#define __j_thread_4                                            \
    z = a[i + j + k] * __roots[j + k];                          \
    a[i + j + k] = a[i + j] - z;                                \
    a[i + j] = a[i + j] + z;                                    \
    ++j;
            __j_thread_4
            __j_thread_4
            __j_thread_4
            __j_thread_4
#undef __j_thread_4
          }
        }
      }
    }
  } // namespace __fft
  std::vector<std::int64_t> __convolution_brute(const std::vector<int>& a, const std::vector<int>& b) {
    int __n = static_cast<int>(a.size());
    int __m = static_cast<int>(b.size());
    std::vector<std::int64_t> __ret(__n + __m - 1);
    for (int __i = 0; __i < __n; __i++)
      for (int __j = 0; __j < __m; __j++)
        __ret[__i + __j] += static_cast<std::int64_t>(a[__i]) * static_cast<std::int64_t>(b[__j]);
    return __ret;
  }
  std::vector<int> __convolution_brute_mod(const std::vector<int>& a, const std::vector<int>& b, int mod) {
    int __n = static_cast<int>(a.size());
    int __m = static_cast<int>(b.size());
    std::vector<int> __ret(__n + __m - 1);
    for (int __i = 0; __i < __n; __i++)
      for (int __j = 0; __j < __m; __j++)
        (__ret[__i + __j] += 1LL * a[__i] * b[__j] % mod) %= mod;
    return __ret;
  }
  std::vector<std::int64_t> square(const std::vector<int>& a) {
    if (a.empty()) return { };
    if (static_cast<int>(a.size()) < 8)
      return __convolution_brute(a, a);
    std::vector<__fft::complex_t> fa, fb;
    int need = static_cast<int>(a.size() + a.size() - 1);
    int nbase = 1;
    while ((1 << nbase) < need) nbase += 1;
    __fft::ensure_base(nbase);
    int sz = 1 << nbase;
    if ((sz >> 1) > static_cast<int>(fa.size()))
      fa.resize(sz >> 1);
    for (int i = 0; i < (sz >> 1); ++i) {
      int x = (2 * i < static_cast<int>(a.size()) ? a[2 * i] : 0);
      int y = (2 * i + 1 < static_cast<int>(a.size()) ? a[2 * i + 1] : 0);
      fa[i] = __fft::complex_t(x, y);
    }
    __fft::fft(fa, sz >> 1);
    __fft::complex_t r(1.0 / (sz >> 1), 0.0);
    for (int i = 0; i <= (sz >> 2); ++i) {
      int j = ((sz >> 1) - i) & ((sz >> 1) - 1);
      __fft::complex_t fe = (fa[i] + __fft::conj(fa[j])) * __fft::complex_t(0.5, 0);
      __fft::complex_t fo = (fa[i] - __fft::conj(fa[j])) * __fft::complex_t(0, -0.5);
      __fft::complex_t aux = fe * fe + fo * fo * __fft::__roots[(sz >> 1) + i] * __fft::__roots[(sz >> 1) + i];
      __fft::complex_t tmp = fe * fo;
      fa[i] = r * (__fft::conj(aux) + __fft::complex_t(0, 2) * __fft::conj(tmp));
      fa[j] = r * (aux + __fft::complex_t(0, 2) * tmp);
    }
    __fft::fft(fa, sz >> 1);
    std::vector<std::int64_t> res(need);
    for (int i = 0; i < need; ++i)
      res[i] = std::llround(i % 2 == 0 ? fa[i >> 1].x : fa[i >> 1].y);
    return res;
  }
  std::vector<std::int64_t> convolution(const std::vector<int> &a, const std::vector<int> &b) {
    if (a.empty() || b.empty()) return { };
    if (a == b) return square(a);
    if (static_cast<int>(std::min(a.size(), b.size())) < 8)
      return __convolution_brute(a, b);
    std::vector<__fft::complex_t> fa, fb;
    int need = static_cast<int>(a.size() + b.size() - 1);
    int nbase = 1;
    while ((1 << nbase) < need) nbase += 1;
    __fft::ensure_base(nbase);
    int sz = 1 << nbase;
    if (sz > static_cast<int>(fa.size())) fa.resize(sz);
    for (int i = 0; i < sz; ++i) {
      int x = (i < static_cast<int>(a.size()) ? a[i] : 0);
      int y = (i < static_cast<int>(b.size()) ? b[i] : 0);
      fa[i] = __fft::complex_t(x, y);
    }
    __fft::fft(fa, sz);
    __fft::complex_t r(0, -0.25 / (sz >> 1));
    for (int i = 0; i <= (sz >> 1); ++i) {
      int j = (sz - i) & (sz - 1);
      __fft::complex_t z = (fa[j] * fa[j] - __fft::conj(fa[i] * fa[i])) * r;
      fa[j] = (fa[i] * fa[i] - __fft::conj(fa[j] * fa[j])) * r;
      fa[i] = z;
    }
    for (int i = 0; i < (sz >> 1); ++i) {
      __fft::complex_t A0 = (fa[i] + fa[i + (sz >> 1)]) * __fft::complex_t(0.5, 0);
      __fft::complex_t A1 = (fa[i] - fa[i + (sz >> 1)]) * __fft::complex_t(0.5, 0) * __fft::__roots[(sz >> 1) + i];
      fa[i] = A0 + A1 * __fft::complex_t(0, 1);
    }
    __fft::fft(fa, sz >> 1);
    std::vector<std::int64_t> res(need);
    for (int i = 0; i < need; ++i)
      res[i] = std::llround(i % 2 == 0 ? fa[i >> 1].x : fa[i >> 1].y);
    return res;
  }
  std::vector<int> convolution_mod(const std::vector<int> &a, const std::vector<int> &b, int mod) {
    if (a.empty() || b.empty()) return { };
    if (static_cast<int>(std::min(a.size(), b.size())) < 8)
      return __convolution_brute_mod(a, b, mod);
    int eq = (a.size() == b.size() && a == b);
    int need = static_cast<int>(a.size() + b.size() - 1);
    int nbase = 0;
    while ((1 << nbase) < need) nbase++;
    __fft::ensure_base(nbase);
    int sz = 1 << nbase;
    std::vector<__fft::complex_t> fa, fb;
    if (sz > (int) fa.size()) fa.resize(sz);
    for (int i = 0; i < (int) a.size(); i++) {
      int x = (a[i] % mod + mod) % mod;
      fa[i] = __fft::complex_t(x & ((1 << 15) - 1), x >> 15);
    }
    std::fill(fa.begin() + a.size(), fa.begin() + sz, __fft::complex_t{0, 0});
    __fft::fft(fa, sz);
    if (sz > static_cast<int>(fb.size())) fb.resize(sz);
    if (eq) {
      std::copy(fa.begin(), fa.begin() + sz, fb.begin());
    }
    else {
      for (int i = 0; i < static_cast<int>(b.size()); i++) {
        int x = (b[i] % mod + mod) % mod;
        fb[i] = __fft::complex_t(x & ((1 << 15) - 1), x >> 15);
      }
      std::fill(fb.begin() + b.size(), fb.begin() + sz, __fft::complex_t{0, 0});
      __fft::fft(fb, sz);
    }
    __fft::real_t ratio = 0.25 / sz;
    __fft::complex_t r2(0, -1);
    __fft::complex_t r3(ratio, 0);
    __fft::complex_t r4(0, -ratio);
    __fft::complex_t r5(0, 1);
    for (int i = 0; i <= (sz >> 1); ++i) {
      int j = (sz - i) & (sz - 1);
      __fft::complex_t a1 = (fa[i] + __fft::conj(fa[j]));
      __fft::complex_t a2 = (fa[i] - __fft::conj(fa[j])) * r2;
      __fft::complex_t b1 = (fb[i] + __fft::conj(fb[j])) * r3;
      __fft::complex_t b2 = (fb[i] - __fft::conj(fb[j])) * r4;
      if (i != j) {
        __fft::complex_t c1 = (fa[j] + __fft::conj(fa[i]));
        __fft::complex_t c2 = (fa[j] - __fft::conj(fa[i])) * r2;
        __fft::complex_t d1 = (fb[j] + __fft::conj(fb[i])) * r3;
        __fft::complex_t d2 = (fb[j] - __fft::conj(fb[i])) * r4;
        fa[i] = c1 * d1 + c2 * d2 * r5;
        fb[i] = c1 * d2 + c2 * d1;
      }
      fa[j] = a1 * b1 + a2 * b2 * r5;
      fb[j] = a1 * b2 + a2 * b1;
    }
    __fft::fft(fa, sz);
    __fft::fft(fb, sz);
    std::vector<int> res(need);
    for (int i = 0; i < need; i++) {
      std::int64_t aa = std::llround(fa[i].x);
      std::int64_t bb = std::llround(fb[i].x);
      std::int64_t cc = std::llround(fa[i].y);
      res[i] = static_cast<int>((aa + ((bb % mod) << 15) + ((cc % mod) << 30)) % mod);
    }
    return res;
  }
} // namespace conv
using conv::convolution;
using conv::convolution_mod;
int buc1[Maxn], buc[Maxn], mxd;
void dfs_calc(int u, int fa, int depth) {
  max_eq(mxd, depth); buc1[depth]++;
  for (const int &v: g[u]) if (!visited[v] && v != fa)
    dfs_calc(v, u, depth + 1);
} // dfs_calc
void divide_calc(int u) {
  vector<int> res{1};
  int mxD = 0; buc[0] = 1;
  add_eq(ans[1], 1);
  for (const int &v: g[u]) if (!visited[v]) {
    mxd = 0; dfs_calc(v, u, 1);
    auto res = convolution_mod(vector(buc, buc + mxD + 1), vector(buc1, buc1 + mxd + 1), mod);
    for (int i = 1; i < (int)res.size(); ++i) add_eq(ans[i + 1], res[i]);
    for (int i = 1; i <= mxd; ++i) buc[i] += buc1[i];
    memset(buc1, 0, sizeof(*buc1) * (mxd + 1));
    max_eq(mxD, mxd);
  }
  memset(buc, 0, sizeof(*buc) * (mxD + 1));
} // divide_calc
int main(void) {
  scanf("%d", &n);
  for (int i = 2; i <= n; ++i) {
    int u, v;
    scanf("%d%d", &u, &v);
    g[u].push_back(v);
    g[v].push_back(u);
  }
  memset(visited, false, sizeof(visited));
  divide(get_root(1, n));
  int Ans = 0;
  for (int i = 1; i <= n; ++i) {
    static int Inv[Maxn];
    Inv[i] = (i == 1 ? 1 : mul(Inv[mod % i], mod - mod / i));
    add_eq(Ans, mul(ans[i] * array{1, 2}[i > 1], Inv[i]));
  }
  for (int i = 1; i <= n; ++i) mul_eq(Ans, i);
  printf("%d\n", Ans);
  exit(EXIT_SUCCESS);
} // main

优化子树合并

​ 很多树上问题的本质是子树合并,直接做的操作复杂度是 O(n2) 的。这时我们就需要优化子树合并的过程。常见的优化子树合并的方式有数据结构维护子树合并树上启发式合并这两种。常用于维护子树合并的数据结构有 线段树、可并堆 等。

树上整体DP

​ 由于很多树形DP的转移是类似于子树合并的过程,因此我们可以利用优化子树合并的方式来优化树形DP转移,这就是树上整体DP的核心思想。大多数树上整体DP使用动态开点线段树维护DP,利用线段树合并来完成DP的转移。

例题 1 Roads in Yusland

​ 考虑DP。设 fu,i 表示结点 u 子树内的边全部被覆盖,且子树内选择的路径可以覆盖到祖先深度为 j 的方案。则有转移

fu,j=min{vCufv,jv[minvCujv=j]}

这是一个树形背包,直接做的时空复杂度都是 O(n2)

​ 考虑使用整体DP优化上述树形背包。我们使用线段树合并维护整体DP,将转移式换一种方式写为

fu,d=minmin{i,j}=d{fu,i+fv,j}=min{minjd{fu,d+fv,j},minid{fu,i+fv,d}}=min{fu,d+minjd{fv,j},fv,d+minid{fu,i}}

这是一个 min 卷积,在线段树合并的同时维护后缀最小值即可。时空复杂度均为 O(nlogn)

参考代码

#include <bits/stdc++.h>
using namespace std;
template<typename _Tp> _Tp &min_eq(_Tp &x, const _Tp &y) { return x = min(x, y); }
template<typename _Tp> _Tp &max_eq(_Tp &x, const _Tp &y) { return x = max(x, y); }
static constexpr int64_t inf = 1e14;
static constexpr int Maxn = 3e5 + 5, MaxN = 7.2005e6 + 5;
int n, m, head[Maxn], en;
struct Edge { int to, nxt; } E[Maxn * 2];
inline void add_edge(int u, int v) { E[++en] = (Edge){v, head[u]}, head[u] = en; }
int head2[Maxn], hn, hv[Maxn], hnxt[Maxn];
int64_t hw[Maxn];
int ls[MaxN], rs[MaxN], tn, root[Maxn];
int64_t tr[MaxN], lz[MaxN];
int stk[MaxN], top;
inline int newnode(void) {
  int nt = (top > 0 ? stk[top--] : ++tn);
  tr[nt] = inf, ls[nt] = rs[nt] = 0, lz[nt] = 0;
  return nt;
} // newnode
inline void pushup(int p) { tr[p] = min({inf, tr[ls[p]], tr[rs[p]]}); }
inline void pushlz(int p, int64_t v) { lz[p] += v, tr[p] += v; }
inline void pushdown(int p) {
  if (lz[p] != 0) {
    if (ls[p]) pushlz(ls[p], lz[p]);
    if (rs[p]) pushlz(rs[p], lz[p]);
    lz[p] = 0;
  }
} // pushdown
void update(int &p, int l, int r, const int &x, const int64_t &w) {
  if (!p) p = newnode();
  if (l == r) return min_eq(tr[p], w), void();
  int mid = (l + r) >> 1; pushdown(p);
  if (x <= mid) update(ls[p], l, mid, x, w);
  else update(rs[p], mid + 1, r, x, w);
  pushup(p);
} // update
void remove(int &p, int l, int r, const int &x) {
  if (!p) return ;
  if (l == r) return tr[p] = inf, void();
  int mid = (l + r) >> 1; pushdown(p);
  if (x <= mid) remove(ls[p], l, mid, x);
  else remove(rs[p], mid + 1, r, x);
  pushup(p);
} // update
int64_t query(int p, int l, int r, int L, int R) {
  if (!p) return inf;
  if (L == l && r == R) return tr[p];
  int mid = (l + r) >> 1; pushdown(p);
  if (R <= mid) return query(ls[p], l, mid, L, R);
  if (L > mid) return query(rs[p], mid + 1, r, L, R);
  return min(query(ls[p], l, mid, L, mid), query(rs[p], mid + 1, r, mid + 1, R));
} // query
int join(int u, int v, int l, int r, int64_t pu, int64_t pv) {
  if (!u && !v) return 0;
  if (!u) return pushlz(v, pv), v;
  if (!v) return pushlz(u, pu), u;
  if (l == r) {
    int p = newnode();
    min_eq(pu, tr[v]); min_eq(pv, tr[u]);
    tr[p] = min(tr[u] + pu, tr[v] + pv);
    min_eq(tr[p], inf);
    stk[++top] = u;
    stk[++top] = v;
    return p;
  }
  int mid = (l + r) >> 1;
  pushdown(u), pushdown(v);
  int p = newnode();
  ls[p] = join(ls[u], ls[v], l, mid, min(pu, tr[rs[v]]), min(pv, tr[rs[u]]));
  rs[p] = join(rs[u], rs[v], mid + 1, r, pu, pv);
  stk[++top] = u;
  stk[++top] = v;
  return pushup(p), p;
} // join
int dep[Maxn];
void dfs(int u, int fa, int depth) {
  dep[u] = depth;
  if (E[head[u]].nxt == 0 && E[head[u]].to == fa) {
    for (int i = head2[u]; i; i = hnxt[i])
      update(root[u], 1, n, dep[hv[i]], hw[i]);
    if (!root[u]) root[u] = newnode();
  } else {
    for (int i = head[u], v; i; i = E[i].nxt) if ((v = E[i].to) != fa) {
      dfs(v, u, depth + 1);
      if (root[u] == 0) root[u] = root[v];
      else root[u] = join(root[u], root[v], 1, n, inf, inf);
    }
    for (int i = head2[u]; i; i = hnxt[i])
      update(root[u], 1, n, dep[hv[i]], min(hw[i] + query(root[u], 1, n, dep[hv[i]], dep[u]), inf));
    if (u != 1) remove(root[u], 1, n, dep[u]);
  }
} // dfs
int main(void) {
  scanf("%d%d", &n, &m); tr[0] = inf;
  if (n == 1) return puts("0"), 0;
  en = 0, memset(head, 0, sizeof(head));
  for (int i = 2, u, v; i <= n; ++i)
    scanf("%d%d", &u, &v), add_edge(u, v), add_edge(v, u);
  hn = 0, memset(head2, 0, sizeof(head2));
  for (int i = 1, u, v; i <= m; ++i) {
    int64_t w; scanf("%d%d%lld", &u, &v, &w);
    if (u != v) hv[++hn] = v, hw[hn] = w, hnxt[hn] = head2[u], head2[u] = hn;
  }
  dfs(1, 0, 1);
  int64_t ans = query(root[1], 1, n, dep[1], dep[1]);
  printf("%lld\n", ans >= inf ? -1 : ans);
  exit(EXIT_SUCCESS);
} // main

例题 2 「CEOI2019」魔法树

​ 考虑DP。设 fu,t 表示恰好在 t 时刻剪断结点 u 与其父亲的边可获得的最大收益。先不考虑结点 u 自身的贡献,在合并两个儿子 u,v 时的转移方程应为

fu,t=maxmax{i,j}=t{fu,i+fv,j}=max{fu,t+max1itfv,i,fv,t+max1itfu,i}

然后再考虑结点 u 自身的贡献,有 fu,du=max1idu{fu,i}+wu。于是我们得到了一个时空复杂度均为 O(n2) 的做法。

​ 考虑使用整体DP优化背包转移。由于转移是 max 卷积,因此在线段树合并的同时维护前缀最大值即可。时空复杂度均为 O(nlogn)

参考代码

#include <bits/stdc++.h>
using namespace std;
static constexpr int Maxn = 1e5 + 5;
int n, m, k;
vector<int> g[Maxn];
int d[Maxn];
int64_t w[Maxn];
struct treedot {
  int ls, rs;
  int64_t val;
  int64_t lz;
} tr[Maxn << 5];
int tot, root[Maxn];
void pushup(int p, int l, int r) {
  tr[p].val = max(tr[tr[p].ls].val, tr[tr[p].rs].val);
} // pushup
void apply(int p, int l, int r, int64_t v) {
  if (!p) return ;
  tr[p].val += v, tr[p].lz += v;
} // apply
void pushdown(int p, int l, int r) {
  if (!p) return ;
  int mid = (l + r) >> 1;
  if (tr[p].lz != 0) {
    apply(tr[p].ls, l, mid, tr[p].lz);
    apply(tr[p].rs, mid + 1, r, tr[p].lz);
    tr[p].lz = 0;
  }
} // pushdown
void modify(int &p, int l, int r, int x, int64_t v) {
  if (!p) p = ++tot;
  if (l == r) {
    tr[p].val = max(tr[p].val, v);
  } else {
    int mid = (l + r) >> 1;
    pushdown(p, l, r);
    if (x <= mid) modify(tr[p].ls, l, mid, x, v);
    else modify(tr[p].rs, mid + 1, r, x, v);
    pushup(p, l, r);
  }
} // modify
int64_t query(int p, int l, int r, int L, int R) {
  if (!p || L > r || l > R) return 0;
  if (L <= l && r <= R) return tr[p].val;
  int mid = (l + r) >> 1;
  pushdown(p, l, r);
  return max(query(tr[p].ls, l, mid, L, R), query(tr[p].rs, mid + 1, r, L, R));
} // query
int join(int u, int v, int l, int r, int64_t pu, int64_t pv) {
  if (!u && !v) return 0;
  if (!u) { apply(v, l, r, pv); return v; }
  if (!v) { apply(u, l, r, pu); return u; }
  if (l == r) {
    pu = max(pu, tr[v].val);
    pv = max(pv, tr[u].val);
    tr[u].val = max(tr[u].val + pu, tr[v].val + pv);
    return u;
  }
  int mid = (l + r) >> 1;
  pushdown(u, l, r); pushdown(v, l, r);
  int64_t lu_val = tr[tr[u].ls].val, lv_val = tr[tr[v].ls].val;
  tr[u].ls = join(tr[u].ls, tr[v].ls, l, mid, pu, pv);
  tr[u].rs = join(tr[u].rs, tr[v].rs, mid + 1, r, max(pu, lv_val), max(pv, lu_val));
  pushup(u, l, r);
  return u;
} // join
void dfs(int u, int fa) {
  for (const int &v: g[u]) if (v != fa) {
    dfs(v, u);
    root[u] = join(root[u], root[v], 1, k, 0LL, 0LL);
  }
  if (d[u] != 0) {
    int64_t W = query(root[u], 1, k, 1, d[u]);
    modify(root[u], 1, k, d[u], W + w[u]);
  }
} // dfs
int main(void) {
  scanf("%d%d%d", &n, &m, &k);
  for (int i = 2, pi; i <= n; ++i) {
    scanf("%d", &pi);
    g[pi].push_back(i);
    g[i].push_back(pi);
  }
  for (int i = 1, v; i <= m; ++i) {
    scanf("%d", &v);
    scanf("%d%lld", &d[v], &w[v]);
  }
  dfs(1, 0);
  printf("%lld\n", tr[root[1]].val);
  exit(EXIT_SUCCESS);
} // main

习题 1 「FJOI2018」领导集团问题

习题 2 「NOI2020」命运

习题 3 「PKUWC2018」Minimax

posted @   cutx64  阅读(18)  评论(1编辑  收藏  举报
相关博文:
阅读排行:
· 单线程的Redis速度为什么快?
· 展开说说关于C#中ORM框架的用法!
· Pantheons:用 TypeScript 打造主流大模型对话的一站式集成库
· SQL Server 2025 AI相关能力初探
· 为什么 退出登录 或 修改密码 无法使 token 失效
点击右上角即可分享
微信分享提示