LOJ #6733. 人造情感

LOJ #6733. 人造情感

​ 先考虑如何求解 W(S)。设 fu 为考虑子树 u 内的路径集合的 W 值,则有转移

fu=max{vchufv}{w+vpath(x,y)favpath(x,y)fv}

其中 (x,y,w) 在给定的路经集合 U 中,且满足条件 lca(x,y)=u。直接做显然就暴毙了,可以思考转移的性质。

​ 观察发现这个转移涉及到“路径下方所挂的点”的求和,于是我们可以想到树上差分。令 fu=fuvchufv,在上式中用 f 替代 f,于是转移就变成了:

fu=max{0}{wvpath(x,y)vufv}

其中 (x,y,w) 在给定的路经集合 U 中,且满足条件 lca(x,y)=u

​ 于是现在只需实现单点加链求和就行了,可以用树状数组维护 dfn 序,时间复杂度为 O(nlogn)

​ 接下来考虑如何计算 f(x,y)。注意到 f(x,y)=froothlca(x,y)upath(x,y)faupath(x,y)fu,其中 hu 为考虑子树 u 外的路径集合的 W 值。后面涉及到“路径下方所挂的点”的求和可以直接用树上差分消掉,有 f(x,y)=froot+upath(x,y)fuflca(u,v)hlca(u,v)。前面有关 f,f 的值是很容易就能求出来的,关键就是如何计算 hu

​ 类似于处理 f 的方法,我们设 hu=huhfauvchfauvufv,那么 h 有和 f 类似的转移式。我们考虑枚举所有经过 fau 但不经过 u 的路径 (x,y,w),设 z=lca(x,y)。对于当前枚举的 u,设 gv={fv,vuhv,vu,容易发现路径 (x,y,w)hu 的贡献为 wvpath(x,y)vzgv。于是我们就得到了一个朴素的做法:先枚举结点 u,然后枚举路径集合 U 中经过 u 的路径 (x,y,w),在更新 g 的同时计算贡献。该做法的时间复杂度为 O(n2)

​ 考虑优化。注意到 U 中的路径 (x,y,w) 只会对其下方挂着的所有点造成贡献;更进一步,其对同一个结点下挂着的所有儿子的贡献是一样的。设结点 u 在路径 (x,y) 上,结点 v 为结点 u 的一个儿子,且满足 v 不在路径 (x,y) 上。设 z=lca(x,y),那么该路径对 v 的贡献为

{w(FxFz)(FyFz),(u=z)w(FxFu)(FyFz)(HuHz),(upath(x,z)uz)w(FxFz)(FyFu)(HuHz),(upath(y,z)uz)

其中 F,H 分别为 f,h 在树上的祖先和,即有 Fx=ypath(root,x)fy。上式中的第二类和第三类可以合在一起讨论,因此只需要讨论 u=zuz 的情况。

  • u=z 此时路径 (x,y,w)z 的子结点 v 有贡献当且仅当 v 不在路径 (x,y) 上。将 z 的所有子结点看成一个序列,那么该路径在序列上的影响可分为 1/2/3 个区间。子结点序列上需要实现的操作是区间取 max 单点查询。因此只需要使用线段树维护 z 的子结点序列即可。
  • uz 不妨设 upath(x,z)。此时路径 (x,y,w)u 的子结点 v 有贡献当且仅当 v 不在路径 (x,y) 上。考虑枚举结点 u 时计算 hv,那么对 hv 有贡献的路径 (x,y,w) 一定存在端点 xu 子树中而不在 v 子树中。我们可以在结点 x,y 上存下路径 (x,y,w) 的贡献,查询时转到 dfn 序上就是一个区间查询。因此我们可以用线段树维护 dfn 序。

时间复杂度为 O((n+m)logn)

参考代码

#include <bits/stdc++.h>
using namespace std;
template<typename _Tp> _Tp &min_eq(_Tp &x, const _Tp &y) { return x = min(x, y); }
template<typename _Tp> _Tp &max_eq(_Tp &x, const _Tp &y) { return x = max(x, y); }
static constexpr int mod = 998244353;
static constexpr int Maxn = 3e5 + 5;
static constexpr int64_t inf = 0x3f3f3f3f3f3f3f3f;
int n, m;
int64_t ans;
vector<int> g[Maxn];
namespace hld {
  int par[Maxn], sz[Maxn], son[Maxn], dep[Maxn];
  int top[Maxn], dfn[Maxn], idfn[Maxn], ed[Maxn], dn;
  void predfs1(int u, int fa, int depth) {
    par[u] = fa, dep[u] = depth; sz[u] = 1, son[u] = 0;
    for (const int &v: g[u]) if (v != par[u]) {
      predfs1(v, u, depth + 1), sz[u] += sz[v];
      if (son[u] == 0 || sz[v] > sz[son[u]]) son[u] = v;
    }
  } // hld::predfs1
  void predfs2(int u, int topv) {
    top[u] = topv, idfn[dfn[u] = ++dn] = u;
    if (son[u] != 0) predfs2(son[u], topv);
    for (const int &v: g[u]) if (v != par[u])
      if (v != son[u]) predfs2(v, v);
  } // hld::predfs2
  inline int get_lca(int u, int v) {
    for (; top[u] != top[v]; v = par[top[v]])
      if (dep[top[u]] > dep[top[v]]) swap(u, v);
    return dep[u] < dep[v] ? u : v;
  } // hld::get_lca
  inline int get_anc(int u, int k) {
    if (k < 0 || k >= dep[u]) return 0;
    for (; dep[u] - dep[par[top[u]]] <= k; u = par[top[u]])
      k -= (dep[u] - dep[par[top[u]]]);
    return idfn[dfn[u] - k];
  } // hld::get_anc
  inline void initialize(int root) {
    dn = 0, predfs1(root, 0, 1), predfs2(root, root);
    for (int i = 1; i <= n; ++i) ed[i] = dfn[i] + sz[i] - 1;
  } // hld::initialize
} // namespace hld
using namespace hld;
namespace fen {
  int64_t b[Maxn];
  inline void clr(void) { memset(b, 0, sizeof(b)); }
  inline void upd(int x, int64_t v) { for (; x <= n; x += x & -x) b[x] += v; }
  inline int64_t ask(int x) { int64_t r = 0; for (; x; x -= x & -x) r += b[x]; return r; }
} // namespace fen
struct path { int x, y; int64_t w; } pa[Maxn];
vector<path> plca[Maxn];
int64_t f[Maxn], F[Maxn], f1[Maxn];
void dfs1(int u, int fa) {
  for (const int &v: g[u]) if (v != fa) dfs1(v, u);
  for (const auto &[x, y, w]: plca[u])
    max_eq(f[u], w - fen::ask(dfn[x]) - fen::ask(dfn[y]));
  fen::upd(dfn[u], f[u]), fen::upd(ed[u] + 1, -f[u]);
} // dfs1
void dfs11(int u, int fa) {
  F[u] = f[u] + F[fa], f1[u] = f[u];
  for (const int &v: g[u]) if (v != fa)
    dfs11(v, u), f1[u] += f1[v];
} // dfs11
int64_t h[Maxn], H[Maxn], h1[Maxn];
namespace sgt1 {
  int64_t tr[Maxn * 4];
  void update(int p, int l, int r, int x, int64_t v) {
    max_eq(tr[p], v);
    if (l == r) return ; int mid = (l + r) / 2;
    if (x <= mid) update(p * 2 + 0, l, mid, x, v);
    else update(p * 2 + 1, mid + 1, r, x, v);
  } // sgt1::update
  int64_t query(int p, int l, int r, int L, int R) {
    if (L > r || l > R) return -inf;
    if (L <= l && r <= R) return tr[p];
    int mid = (l + r) / 2;
    return max(query(p * 2 + 0, l, mid, L, R), query(p * 2 + 1, mid + 1, r, L, R));
  } // sgt1::query
  inline void upd(int x, int64_t v) { return update(1, 1, n, x, v); }
  inline int64_t ask(int l, int r) { return query(1, 1, n, l, r); }
} // namespace sgt1
namespace sgt2 {
  int N; int64_t tr[Maxn * 4];
  void build(int n) { N = n, memset(tr, -63, (n + 1) * 4 * sizeof(*tr)); }
  void update(int p, int l, int r, int L, int R, int64_t v) {
    if (L > r || l > R) return ;
    if (L <= l && r <= R) return max_eq(tr[p], v), void();
    int mid = (l + r) / 2;
    update(p * 2 + 0, l, mid, L, R, v);
    update(p * 2 + 1, mid + 1, r, L, R, v);
  } // sgt2::update
  int64_t query(int p, int l, int r, int x) {
    if (l == r) return tr[p];
    int mid = (l + r) / 2; int64_t t = tr[p];
    if (x <= mid) max_eq(t, query(p * 2 + 0, l, mid, x));
    else max_eq(t, query(p * 2 + 1, mid + 1, r, x));
    return t;
  } // sgt2::query
  inline void upd(int l, int r, int64_t v) { return update(1, 1, N, l, r, v); }
  inline int64_t ask(int x) { return query(1, 1, N, x); }
} // namespace sgt2
void dfs2(int u, int fa) {
  H[u] = H[fa] + h[u];
  static int label[Maxn]; int N = 0;
  for (const int &v: g[u]) if (v != fa) label[v] = ++N;
  if (N != 0) {
    for (const int &v: g[u]) if (v != fa)
      max_eq(h[v], max(sgt1::ask(dfn[u], dfn[v] - 1), sgt1::ask(ed[v] + 1, ed[u])) + F[u] - H[u]);
    sgt2::build(N);
    for (auto [x, y, w]: plca[u]) {
      if (dep[x] > dep[y]) swap(x, y);
      int xk = get_anc(x, dep[x] - dep[u] - 1);
      int yk = get_anc(y, dep[y] - dep[u] - 1);
      if (xk == 0 && yk == 0) {
        sgt2::upd(1, N, w);
      } else if (xk == 0) {
        int64_t v = w - F[y] + F[u];
        if (1 < label[yk]) sgt2::upd(1, label[yk] - 1, v);
        if (label[yk] < N) sgt2::upd(label[yk] + 1, N, v);
      } else {
        int64_t v = w - F[x] + F[u] - F[y] + F[u];
        if (label[xk] > label[yk]) swap(x, y), swap(xk, yk);
        if (1 < label[xk]) sgt2::upd(1, label[xk] - 1, v);
        if (label[yk] < N) sgt2::upd(label[yk] + 1, N, v);
        if (label[xk] + 1 <= label[yk] - 1) sgt2::upd(label[xk] + 1, label[yk] - 1, v);
      }
    }
    for (const int &v: g[u]) if (v != fa)
      max_eq(h[v], sgt2::ask(label[v]));
  }
  for (auto [x, y, w]: plca[u]) {
    if (dep[x] > dep[y]) swap(x, y);
    int xk = get_anc(x, dep[x] - dep[u] - 1);
    int yk = get_anc(y, dep[y] - dep[u] - 1);
    if (xk == 0 && yk == 0) {
    } else if (xk == 0) {
      int64_t v = w - F[y] + H[u];
      sgt1::upd(dfn[y], v);
    } else {
      int64_t v = w - F[x] - F[y] + H[u] + F[u];
      sgt1::upd(dfn[x], v);
      sgt1::upd(dfn[y], v);
    }
  }
  for (const int &v: g[u]) if (v != fa) dfs2(v, u);
} // dfs2
void dfs3(int u, int fa) {
  for (const int &v: g[u]) if (v != fa)
    h1[v] = h1[u] + f1[u] - f1[v] - f[u] + h[v], dfs3(v, u);
} // dfs3
int64_t sf[Maxn];
void dfs4(int u, int fa) {
  for (const int &v: g[u]) if (v != fa) dfs4(v, u), (sf[u] += sf[v]) %= mod;
  int64_t z = (int64_t)sz[u] * sz[u];
  for (const int &v: g[u]) if (v != fa) z -= (int64_t)sz[v] * sz[v];
  ((ans -= (z % mod) * ((f1[u] + h1[u]) % mod) % mod) += mod) %= mod;
  for (const int &v: g[u]) if (v != fa) (ans += 2 * sf[v] * (sz[u] - sz[v]) % mod) %= mod;
  (ans += (f[u] % mod) * (z % mod) % mod) %= mod;
  (sf[u] += sz[u] * f[u] % mod) %= mod;
} // dfs4
int main(void) {
  scanf("%d%d", &n, &m);
  for (int i = 2; i <= n; ++i) {
    int u, v;
    scanf("%d%d", &u, &v);
    g[u].push_back(v);
    g[v].push_back(u);
  }
  hld::initialize(1);
  for (int i = 1; i <= m; ++i) {
    scanf("%d%d%lld", &pa[i].x, &pa[i].y, &pa[i].w);
    int z = get_lca(pa[i].x, pa[i].y);
    plca[z].push_back(pa[i]);
  }
  fen::clr(), dfs1(1, 0), dfs11(1, 0);
  memset(sgt1::tr, -63, sizeof(sgt1::tr));
  dfs2(1, 0), dfs3(1, 0);
  ans = 0, dfs4(1, 0);
  (ans += (int64_t)n * n % mod * (f1[1] % mod) % mod) %= mod;
  printf("%lld\n", ans);
  exit(EXIT_SUCCESS);
} // main

posted @   cutx64  阅读(142)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 地球OL攻略 —— 某应届生求职总结
· 周边上新:园子的第一款马克杯温暖上架
· Open-Sora 2.0 重磅开源!
· 提示词工程——AI应用必不可少的技术
· .NET周刊【3月第1期 2025-03-02】
点击右上角即可分享
微信分享提示