[解题报告] 「BJOI2017」树的难题(点分治 + 线段树/单调队列)

传送🚪

题意

一个点数为 \(n\) 的树,每个点有颜色 \(c_i\),每个颜色有权值 \(v_i\)

一条路径的权值定义为该条路径上各同色连通块的权值之和。

例如,一条颜色序列为 1 2 2 1 的路径,其权值为 \(v_1 + v_2 + v_1\)

求长度在 \([L,R]\) 之前的路径的权值最大值。

\(n \le 2 \times 10^5, c \le n\)

解法

解法一 线段树

树上和距离有关的问题首先考虑点分治。对于每个节点 u,

把它的儿子按照颜色排序,让颜色相同的儿子放在一起。

点分治时开两棵线段树,分别记录与当前颜色相同 / 不同的最大值,遍历完该种颜色后把贡献加到第二棵线段树上。

时间复杂度为 \(O(n\log^2 n)\)

解法二 按秩合并单调队列

考虑直接把路径扣出来,然后用单调队列处理(不同颜色分开处理,相同颜色的就按照子树分开处理)。

但是由于这里单调队列会有个初始化的复杂度,就是把取值区间的左端点从 \(maxdis\) 移动到 \(L\)(其中 \(maxdis\) 表示之前遍历过的子树的最大深度),这样复杂度可以被卡到 \(O(n^2)\)

有一个叫 “按秩合并单调队列” 的做法。

就是把扣出来的路径排序,不同颜色之间按照该颜色中路径长度最大值从小到大排序,相同颜色的路径也按照长度从小到大排序,这样单调队列的复杂度就是对的(可以看做把之前的路径长度带来的复杂度算到自己身上,然后初始化的总复杂度就是 \(O(sz)\) 的),复杂度是 \(O(n \log n)\) 的。

解法三 带回溯单调队列

当然如果像我一样蠢,可以写一个带回溯的单调队列,然后调两个晚上

其实是因为我第一次写线段树做法的时候脑子抽了,点分治的时候把每个点的儿子 \(reverse\) 之后再做了一遍……然后它就 T 了……然后我看一眼讨论区,发现 “单调队列” 四个字,然后就想都没想就直接 \(Dfs\) 的时候用单调队列维护……然后它 WA 了……然后我发现好像要回溯……然后我写了个回溯……然后它 T 了(原因和解法二中的差不多)……然后我就调了两个晚上……

代码

线段树

#include <algorithm>
#include <cstdio>
#include <iostream>
#include <vector>

#define pb push_back
#define mkp make_pair
#define fi first
#define se second

using namespace std;

const int _ = 2e5 + 7;
const int __ = 2e7 + 7;
const int inf = 2e9 + 7;

int n;

struct SGT {
#define mid ((l + r) >> 1)

  int rt, tot, maxn[__], ls[__], rs[__];

  void clear() { rt = tot = 0; }

  void Modify(int &k, int l, int r, int x, int w) {
    if (!k) k = ++tot, maxn[k] = -inf, ls[k] = rs[k] = 0;
    if (l == r) return (void)(maxn[k] = max(maxn[k], w));
    if (x <= mid) Modify(ls[k], l, mid, x, w);
    else Modify(rs[k], mid + 1, r, x, w);
    maxn[k] = max(ls[k] ? maxn[ls[k]] : -inf, rs[k] ? maxn[rs[k]] : -inf);
  }

  void Modify(int x, int w) { Modify(rt, 1, n, x, w); }

  int Query(int &k, int l, int r, int x, int y) {
    if (!k) k = ++tot, maxn[k] = -inf, ls[k] = rs[k] = 0;
    if (l >= x and r <= y) return maxn[k];
    int t1 = -inf, t2 = -inf;
    if (x <= mid) t1 = Query(ls[k], l, mid, x, y);
    if (y > mid) t2 = Query(rs[k], mid + 1, r, x, y);
    return max(t1, t2);
  }

  int Query(int l, int r) {
    if (l > r) return -inf;
    maxn[0] = inf;
    return Query(rt, 1, n, l, r);
  }

#undef mid
} S, T;

int m, L, R, val[_], ans = -inf;
vector<pair<int, int>> to[_];

int gi() {
  int x = 0; bool f = 0; char c = getchar();
  while (!isdigit(c) and c != '-') c = getchar();
  if (c == '-') f = 1, c = getchar();
  while (isdigit(c)) x = (x << 3) + (x << 1) + c - '0', c = getchar();
  return f ? -x : x;
}

void Init() {
  n = gi(), m = gi(), L = gi(), R = gi();
  for (int i = 1; i <= m; ++i) val[i] = gi();
  for (int i = 1, x, y, c; i < n; ++i) {
    x = gi(), y = gi(), c = gi();
    to[x].pb(mkp(c, y));
    to[y].pb(mkp(c, x));
  }
  for (int i = 1; i <= n; ++i) sort(to[i].begin(), to[i].end());
}

int rt, minx, numV, sz[_], top;
pair<int, int> box[_];
bool vis[_];

void GetSz(int u, int fa) {
  ++numV, sz[u] = 1;
  for (auto x: to[u])
    if (!vis[x.se] and x.se != fa) GetSz(x.se, u), sz[u] += sz[x.se];
}

void FindRt(int u, int fa) {
  int maxsz = 0;
  for (auto x: to[u])
    if (!vis[x.se] and x.se != fa) FindRt(x.se, u), maxsz = max(maxsz, sz[x.se]);
  maxsz = max(maxsz, numV - sz[u]);
  if (maxsz < minx) rt = u, minx = maxsz;
}

void Stat(int u, int fa, int w, int lst, int t, int dis) {
  if (dis > R) return;
  ans = max(ans, w + S.Query(max(1, L - dis), min(n, R - dis)));
  ans = max(ans, w - t + T.Query(max(1, L - dis), min(n, R - dis)));
  if (dis >= L) ans = max(ans, w);
  for (auto x: to[u])
    if (!vis[x.se] and x.se != fa) Stat(x.se, u, w + (x.fi != lst) * val[x.fi], x.fi, t, dis + 1);
}

void Cont(int u, int fa, int w, int lst, int dis) {
  if (dis > R) return;
  T.Modify(dis, w), box[++top] = mkp(dis, w);
  for (auto x: to[u])
    if (!vis[x.se] and x.se != fa) Cont(x.se, u, w + (x.fi != lst) * val[x.fi], x.fi, dis + 1);
}

void Calc() {
  int lst = 0; top = 0;
  for (auto x: to[rt]) {
    int v = x.se;
    if (vis[v]) continue;
    if (x.fi != lst) {
      while (top) S.Modify(box[top].fi, box[top].se), --top;
      T.clear(), lst = x.fi;
    }
    Stat(v, rt, val[x.fi], x.fi, val[x.fi], 1);
    Cont(v, rt, val[x.fi], x.fi, 1);
  }
  S.clear(), T.clear();
}

void Work(int u) {
  rt = 0, minx = inf, numV = 0;
  GetSz(u, 0);
  FindRt(u, 0);
  
  int tmp = rt; vis[rt] = 1;
  for (auto x: to[tmp])
    if (!vis[x.se]) Work(x.se);
  rt = tmp, vis[rt] = 0;

  Calc();
}

int main() {
  Init();
  Work(1);
  cout << ans << endl;
  return 0;
}

带回溯单调队列

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>

#define pb push_back
#define mkp make_pair
#define fi first
#define se second

using namespace std;

const int _ = 2e5 + 7;
const int __ = 2e6 + 7;
const int inf = 2e9 + 7;

int n;

struct SGT {
#define mid ((l + r) >> 1)

  int lc, rc, maxn[_], box[_], top, val, q[_], t1, t2, cnt, tot;
  pair<int, int> a[__];
  struct NODE { int lc, rc, t1, t2, st; } rev[_];

  void Init() { memset(maxn, -0x3f, sizeof maxn); lc = rc = n + 1, t1 = 1, t2 = 0; }

  void Recall() {
    t1 = rev[cnt].t1, t2 = rev[cnt].t2, lc = rev[cnt].lc, rc = rev[cnt].rc;
    while (tot > rev[cnt].st) q[a[tot].fi] = a[tot].se, --tot;
    --cnt;
  }
  
  void clear() {
    int x = 0;
    while (top) x = max(x, box[top]), maxn[box[top--]] = -inf;
    while (cnt and lc <= x) Recall();
  }

  void reset(int p) { lc = rc = p, t1 = 1, t2 = 0, cnt = 0; }

  void Modify(int x, int w) {
    box[++top] = x;
    maxn[x] = max(maxn[x], w);
    while (cnt and lc <= x) Recall();
  }

  int Query(int l, int r) {
    rev[++cnt] = { lc, rc, t1, t2, tot };
    while (lc > l) {
      --lc;
      while (t2 >= t1 and maxn[q[t2]] <= maxn[lc]) {
        a[++tot] = mkp(t2, q[t2]);
        --t2;
      }
      q[++t2] = lc;
    }
    rc = min(rc, r);
    while (t1 <= t2 and (q[t1] > rc or q[t1] < lc)) a[++tot] = mkp(t1, q[t1]), ++t1;
    return t1 > t2 ? -inf : maxn[q[t1]];
  }


#undef mid
} S, T;

int m, L, R, val[_], ans = -inf;
vector<pair<int, int>> to[_];

int gi() {
  int x = 0; bool f = 0; char c = getchar();
  while (!isdigit(c) and c != '-') c = getchar();
  if (c == '-') f = 1, c = getchar();
  while (isdigit(c)) x = (x << 3) + (x << 1) + c - '0', c = getchar();
  return f ? -x : x;
}

void Init() {
  n = gi(), m = gi(), L = gi(), R = gi();
  for (int i = 1; i <= m; ++i) val[i] = gi();
  for (int i = 1, x, y, c; i < n; ++i) {
    x = gi(), y = gi(), c = gi();
    to[x].pb(mkp(c, y));
    to[y].pb(mkp(c, x));
  }
  for (int i = 1; i <= n; ++i) sort(to[i].begin(), to[i].end());
  S.Init(), T.Init();
}

int rt, minx, numV, sz[_], top, dep[_];
pair<int, int> box[_];
bool vis[_];

void GetSz(int u, int fa) {
  ++numV, sz[u] = 1, dep[u] = 1;
  for (auto x: to[u])
    if (!vis[x.se] and x.se != fa) GetSz(x.se, u), sz[u] += sz[x.se], dep[u] = max(dep[u], dep[x.se] + 1);
}

void FindRt(int u, int fa) {
  int maxsz = 0;
  for (auto x: to[u])
    if (!vis[x.se] and x.se != fa) FindRt(x.se, u), maxsz = max(maxsz, sz[x.se]);
  maxsz = max(maxsz, numV - sz[u]);
  if (maxsz < minx) rt = u, minx = maxsz;
}

void Stat(int u, int fa, int w, int lst, int t, int dis) {
  if (dis > R) return;
  ans = max(ans, w + S.Query(max(1, L - dis), min(n, R - dis)));
  ans = max(ans, w - t + T.Query(max(1, L - dis), min(n, R - dis)));
  if (dis >= L) ans = max(ans, w);
  for (auto x: to[u])
    if (!vis[x.se] and x.se != fa) Stat(x.se, u, w + (x.fi != lst) * val[x.fi], x.fi, t, dis + 1);
  S.Recall();
  T.Recall();
}

void Cont(int u, int fa, int w, int lst, int dis) {
  if (dis > R) return;
  T.Modify(dis, w);
  box[++top] = mkp(dis, w);
  for (auto x: to[u])
    if (!vis[x.se] and x.se != fa) Cont(x.se, u, w + (x.fi != lst) * val[x.fi], x.fi, dis + 1);
}

void Calc() {
  int lst = 0; top = 0;

  S.reset(dep[rt] + 1), T.reset(dep[rt] + 1);
  while (S.lc > L) S.Query(S.lc - 1, R);
  while (T.lc > L) T.Query(T.lc - 1, R);

  for (auto x: to[rt]) {
    int v = x.se;
    if (vis[v]) continue;
    if (x.fi != lst) {
      while (top) S.Modify(box[top].fi, box[top].se), --top;
      T.clear();
      lst = x.fi;
    }
    while (S.lc > L) S.Query(S.lc - 1, R);
    while (T.lc > L) T.Query(T.lc - 1, R);
    Stat(v, rt, val[x.fi], x.fi, val[x.fi], 1);
    Cont(v, rt, val[x.fi], x.fi, 1);
  }
  S.clear();
  T.clear();
}

void Work(int u) {
  rt = 0, minx = inf, numV = 0;
  GetSz(u, 0);
  FindRt(u, 0);
  
  int tmp = rt; vis[rt] = 1;
  for (auto x: to[tmp])
    if (!vis[x.se]) Work(x.se);
  rt = tmp, vis[rt] = 0;
  
  Calc();
}

int main() {
  Init();
  Work(1);
  cout << ans << endl;
  return 0;
}
posted @ 2021-01-18 22:48  BruceW  阅读(197)  评论(0编辑  收藏  举报