P4719 【模板】"动态 DP"&动态树分治

知识点:树形 DP,矩阵乘法,重链剖分,线段树

原题面:Luogu

宣传一波:「笔记」广义矩阵乘法与 DP

简述

给定一棵 \(n\) 个点的树,点有点权。给定 \(m\) 次点权修改操作,求每次操作后整棵树的 最大点权独立集 的权值。
一棵树的独立集定义为满足任意一条边的两端点都不同时存在于集合中的树的一个点集,一个独立集的价值定义为集合中所有点的点权之和。
\(1\le n,m\le 10^5\)\(-100\le\) 点权 \(\le 100\)
1S,256MB。

分析

前置知识

以下简单介绍广义矩阵乘法。

对于一 \(p\times m\) 的矩阵 \(A\),与 \(m\times q\) 的矩阵 \(B\),定义广义矩阵乘法 \(A\times B = C\) 的结果是一个 \(p\times q\) 的矩阵 \(C\),满足:

\[C_{i,j} = (A_{i, 1}\otimes B_{1,j}) \oplus (A_{i,2}\otimes B_{2,j})\oplus \cdots \oplus (A_{i, n}\otimes B_{n,j}) \]

其中 \(\oplus\)\(\otimes\) 是两种二元运算。

考察这种广义矩阵乘法是否满足结合律:

\[\begin{aligned} ((AB)C)_{i,j} &= \bigoplus_{k=1}^{p}(AB)_{i,k}\otimes C_{k,j}\\ &= \bigoplus_{k=1}^{p}\left( \left(\bigoplus_{t=1}^{q} A_{i,t}\otimes B_{t,k}\right) \otimes C_{k,j}\right)\\ (A(BC))_{i,j} &= \bigoplus_{t=1}^{q}A_{i,t}\otimes (BC)_{t,j}\\ &= \bigoplus_{t=1}^{q} \left(A_{i,t}\otimes \left(\bigoplus_{k=1}^{p} B_{t,k} \otimes C_{k,j}\right)\right) \end{aligned}\]

显然,\(\otimes\) 运算满足交换律、结合律,且 \(\otimes\)\(\oplus\) 存在分配律,即存在 \(\left(\bigoplus a\right)\otimes b = \bigoplus \left( a\otimes b \right)\) 时,广义矩阵乘法满足结合律。根据上述运算规律,对二式进行 \(\oplus\) 的交换后有:

\[((AB)C)_{i,j} = (A(BC))_{i,j} = \bigoplus_{k=1}^{p}\bigoplus_{t=1}^{q} \left(A_{i,t}\otimes B_{t,k}\otimes C_{k,j}\right) \]

回到此题

先考虑朴素 DP。钦定 1 为根,设 \(f_{u,0/1}\) 表示钦定点 \(u\) 不在/在 独立集时以 \(u\) 为根的子树的最大点权独立集的权值,显然有:

\[\begin{aligned} f_{u,0} &= \sum_{v\in \mathbf{son}(u)} \max(f_{v, 0}, f_{v,1})\\ f_{u,1} &= \operatorname{val}_u + \sum_{v\in \operatorname{son}(u)} f_{v, 0} \end{aligned}\]

答案即为 \(\max (f_{1,0}, f_{1, 1})\)

要求支持修改,又树的形态不变,考虑用树链剖分维护。但发现每个节点的 DP 值与其所有儿子有关,而树剖只能支持修改重链/子树信息。于是考虑对于每个节点,先将其轻儿子的贡献求和,再考虑其重儿子的贡献,使得可以通过对重链的修改/查询来维护上述信息。这种思想在 LCT 维护子树信息时也有所应用。
\(g_{u,0/1}\) 表示钦定 \(u\) 的重儿子不在独立集,点 \(u\) 不在/在 独立集时以 \(u\) 为根的子树的最大点权独立集的权值。记 \(\operatorname{s}_u\) 表示 \(u\) 的重儿子,显然有:

\[\begin{aligned} g_{u,0} &= \sum_{v\in \operatorname{son}(u) \land v\not= \operatorname{s}_u}\max(f_{v,0}, f_{v,1})\\ g_{u,1} &= \operatorname{val}_u + \sum_{v\in \operatorname{son}(u) \land v\not = \operatorname{s}_u} f_{v, 0} \end{aligned}\]

则对 \(f\) 的转移可以改写成下列形式:

\[\begin{aligned} f_{u,0} &= g_{u,0} + \max(f_{\operatorname{s}_u, 0}, f_{\operatorname{s}_u,1})\\ f_{u,1} &= g_{u,1} + f_{\operatorname{s}_u, 0} \end{aligned}\]

考虑加法运算运算与取 \(\max\) 运算的性质:发现取 \(\max\) 满足交换律与结合律,且加法对取 \(\max\) 满足分配率,即有:

\[a + \max_{i}(b_i) = \max_{i}(a + b_i) \]

考虑定义一种广义矩阵乘法 \(A\times B = C\),满足:

\[C_{i,j} = \max_{k}\left( A_{i,k} +B_{k,j}\right) \]

根据上述转移方程,有下列关系成立。

\[\large \begin{bmatrix} g_{u,0}& g_{u,0}\\ g_{u,1}& -\infin \end{bmatrix} \times \begin{bmatrix} f_{\operatorname{s}_u, 0}\\ f_{\operatorname{s}_u, 1} \end{bmatrix} = \begin{bmatrix} f_{u, 0}\\ f_{u, 1} \end{bmatrix}\]

于是可以考虑先预处理出 \(g\) 数组初始化转移矩阵,再使用线段树维护区间矩阵乘积。转移矩阵写在前面是因为 dfs 序列中深度较浅的点在前,转移矩阵写在前面可以直接按 dfs 序求得区间矩阵乘积并转移。若转移矩阵写在后面,需要先将区间内的元素顺序反转。经过预处理后,求得以 1 为根的重链对应区间的矩阵乘积,即得 \(f_{u,0}\)\(f_{u,1}\)。正确性显然,重链一定以某叶节点为链底,以 1 为根的重链上所有轻儿子子树信息的并即为整棵树的信息。

考虑修改操作对哪些位置的 \(g\) 会产生影响。考虑其实际含义,\(g\) 维护的是轻儿子子树信息。被影响的节点显然为指定的修改位置 \(x\),以及子树中包含被修改位置,且为轻儿子的节点的父亲,后者可以通过从被修改位置不断跳重链来进行遍历。每次跳到的重链的顶的父亲,即为对应节点。
每次更新上述节点时先求得修改前以该节点的对应轻儿子的子树信息,修改子树中的节点后再求得该节点的对应轻儿子子树信息。根据两次求得的子树信息的差更新该节点的 \(g\),并将即将被修改的节点调整为当前节点。建议结合代码理解。

总复杂度 \(O(8n\log n + 8m\log^2 n)\) 级别。

代码

//知识点:重链剖分,线段树,树形 DP
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#define LL long long
const int kN = 1e5 + 10;
const int kL = 2;
const int kInf = 1e9 + 2077;
//=============================================================
int n, m, e_num, head[kN], val[kN], v[kN << 1], ne[kN << 1];
int dfn_num, dfn[kN], id[kN], f[kN][2], g[kN][2];
//=============================================================
inline int read() {
  int f = 1, w = 0;
  char ch = getchar();
  for (; !isdigit(ch); ch = getchar())
    if (ch == '-') f = -1;
  for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
  return f * w;
}
void Chkmax(int &fir, int sec) {
  if (sec > fir) fir = sec;
}
void Chkmin(int &fir, int sec) {
  if (sec < fir) fir = sec;
}
struct Matrix {
  int a[kL][kL];
  Matrix() {
    memset(a, 0, sizeof (a));
  }
  void build() {
    for (int i = 1; i <= kL; ++ i) a[i][i] = 1;
  }
  Matrix operator * (const Matrix &b_) const {
    Matrix ret;
    memset(ret.a, 128, sizeof (ret.a));
    for (int k = 0; k < kL; ++ k) {
      for (int i = 0; i < kL; ++ i) {
        for (int j = 0; j < kL; ++ j) {
          ret.a[i][j] = std::max(ret.a[i][j], a[i][k] + b_.a[k][j]);
        }
      }
    }
    return ret;
  }
} matrix[kN];
void Add(int u_, int v_) {
  v[++ e_num] = v_;
  ne[e_num] = head[u_];
  head[u_] = e_num;
} 
namespace Seg { //维护区间矩阵乘积
  #define ls (now_<<1)
  #define rs (now_<<1|1)
  #define mid ((L_+R_)>>1)
  Matrix sum[kN << 2];
  void Pushup(int now_) {
    sum[now_] = sum[ls] * sum[rs];
  }
  void Build(int now_, int L_, int R_) {
    if (L_ == R_) {
      sum[now_] = matrix[L_];
      return ;
    }
    Build(ls, L_, mid), Build(rs, mid + 1, R_);
    Pushup(now_); 
  }
  void Modify(int now_, int L_, int R_, int pos_) {
    if (L_ == R_) {
      sum[now_] = matrix[pos_];
      return ;
    }
    if (pos_ <= mid) Modify(ls, L_, mid, pos_);
    else Modify(rs, mid + 1, R_, pos_);
    Pushup(now_);
  }
  Matrix Query(int now_, int L_, int R_, int l_, int r_) {
    if (l_ == L_ && R_ == r_) return sum[now_];
    if (r_ <= mid) return Query(ls, L_, mid, l_, r_);
    if (l_ > mid) return Query(rs, mid + 1, R_, l_, r_);
    return Query(ls, L_, mid, l_, mid) * Query(rs, mid + 1, R_, mid + 1, r_);
  }
  #undef ls
  #undef rs
  #undef mid
}
namespace HLD {
  int fa[kN], sz[kN], son[kN], dep[kN], top[kN], end[kN];
  void Dfs1(int u_, int fa_) {
    sz[u_] = 1;
    fa[u_] = fa_;
    f[u_][1] = val[u_];
    dep[u_] = dep[fa_] + 1;
    
    for (int i = head[u_]; i; i = ne[i]) { //预处理 f
      int v_ = v[i];
      if (v_ == fa_) continue ;
      Dfs1(v_, u_);
      sz[u_] += sz[v_];
      if (sz[v_] > sz[son[u_]]) son[u_] = v_;
      f[u_][0] += std::max(f[v_][0], f[v_][1]);
      f[u_][1] += f[v_][0];
    }
  }
  void Dfs2(int u_, int top_) {
    dfn[u_] = ++ dfn_num;
    id[dfn_num] = u_;
    top[u_] = top_;
    Chkmax(end[top_], dfn_num);

    if (son[u_]) Dfs2(son[u_], top_);
    g[u_][1] = val[u_];
    for (int i = head[u_]; i; i = ne[i]) { //预处理 g
      int v_ = v[i];
      if (v_ == fa[u_] || v_ == son[u_]) continue ;
      Dfs2(v_, v_);
      g[u_][0] += std::max(f[v_][0], f[v_][1]);
      g[u_][1] += f[v_][0];
    }
  }
  void Modify(int u_, int val_) {
    matrix[dfn[u_]].a[1][0] += val_ - val[u_]; //修改 u_ 的 g[u_][1]
    val[u_] = val_; //更新点权
    while (u_) { //u_ 不断上跳
      Matrix old = Seg::Query(1, 1, n, dfn[top[u_]], end[top[u_]]); //以 top[u_] 为根的子树的信息
      Seg::Modify(1, 1, n, dfn[u_]); //修改节点 u_ 的信息(单点修改矩阵)
      Matrix newone = Seg::Query(1, 1, n, dfn[top[u_]], end[top[u_]]); //更新后以 top[u_] 为根的子树的信息
      u_ = fa[top[u_]]; //更新轻儿子 u_ 的父亲的 g
      //注意下文的赋值还未更新到线段树上,上面需要求得未修改之前的信息,更新线段树信息要在之后进行
      matrix[dfn[u_]].a[0][0] += std::max(newone.a[0][0], newone.a[1][0]) - 
                                 std::max(old.a[0][0], old.a[1][0]);
      matrix[dfn[u_]].a[0][1] = matrix[dfn[u_]].a[0][0];
      matrix[dfn[u_]].a[1][0] += newone.a[0][0] - old.a[0][0];
    }
  }
  int Query() { //求得以 1 为根的重链对应区间的矩阵乘积,即得答案
    //重链一定以某叶节点为链底,以 1 为根的重链上所有轻儿子子树信息的并即为整棵树的信息。
    Matrix ans = Seg::Query(1, 1, n, 1, end[1]);
    return std::max(ans.a[0][0], ans.a[1][0]);
  }
}
//=============================================================
int main() {
  n = read(), m = read();
  for (int i = 1; i <= n; ++ i) val[i] = read();
  for (int i = 1; i < n; ++ i) {
    int u_ = read(), v_ = read();
    Add(u_, v_), Add(v_, u_);
  }
  HLD::Dfs1(1, 0), HLD::Dfs2(1, 1);
  for (int i = 1; i <= n; ++ i) { //构造转移矩阵
    matrix[dfn[i]].a[0][0] = matrix[dfn[i]].a[0][1] = g[i][0];
    matrix[dfn[i]].a[1][0] = g[i][1], matrix[dfn[i]].a[1][1] = -kInf;
  }
  Seg::Build(1, 1, n);
  while (m --) {
    int x_ = read(), y_ = read();
    HLD::Modify(x_, y_);
    printf("%d\n", HLD::Query());
  }
  return 0; 
}
posted @ 2021-02-23 19:55  Luckyblock  阅读(104)  评论(0编辑  收藏  举报