[CF504E] Misha and LCP on Tree 题解

Description

给定一棵 \(n\) 个节点的树,每个节点有一个小写字母。

\(m\) 组询问,每组给出 \(a,b,c,d\),求树上路径 \(a \to b\) 与路径 \(c \to d\) 上的字母依次组成的两个字符串的最长公共前缀长度。

\(n \le 3 \times 10^5,m \le 10^6\)

Sol

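注意到最长公共前缀具有单调性(若长度为 \(k\) 的前缀相同,则所有更短的前缀也相同),因此可以用树上哈希 + 二分答案来解决。树上路径的哈希这样维护:预处理从根向下到每个点的哈希,以及从每个点向上到根的哈希;查询路径 \(x \to y\) 时,以 \(\operatorname{lca}(x,y)\) 为分界,把“\(x\) 向上走到 lca”和“lca 向下走到 \(y\)”两段的哈希拼起来即可。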

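二分长度 \(mid\) 时,需要找到路径上的第 \(mid\) 个点:若它位于 \(a\) 到 lca 的上行段,就是 \(a\) 的 \(mid-1\) 级祖先;否则是 \(b\) 的 \(len-mid\) 级祖先(\(len\) 为路径上的点数)。因此还需要写一个长链剖分,以 \(O(1)\) 的时间查询 \(k\) 级祖先。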

时间复杂度 \(O((n+m)\log n)\):预处理(倍增表、ST 表)为 \(O(n\log n)\),每次询问二分 \(O(\log n)\) 次,每次求 \(k\) 级祖先和路径哈希都是 \(O(1)\)。

Code

#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize("Ofast")
#include<bits/stdc++.h>
// Modular add / subtract for operands already reduced into [0, Mod).
// Both expand in the caller's scope, so a variable or member named `Mod`
// must be visible there. Every argument is parenthesized so that compound
// expressions such as dec(x, y + z) expand correctly.
#define add(a,b) ((a)+(b)>=Mod?(a)+(b)-Mod:(a)+(b))
#define dec(a,b) ((a)<(b)?(a)-(b)+Mod:(a)-(b))
// `register` is ill-formed since C++17 and was only an optimisation hint
// before that, so `re` now expands to nothing. Existing `re int ...`
// declarations keep compiling under every standard.
#define re
#define il inline
using namespace std;
namespace IO {
// Input buffer, refilled from stdin in 2 MiB chunks by ch().
char buf_[1 << 21], *p1_ = buf_, *p2_ = buf_;
// ch(): the next input byte (as an int), or EOF once stdin is exhausted.
// Kept as a macro because main() also calls it directly.
#define ch()                                                                 \
  (p1_ == p2_ &&                                                             \
           (p2_ = (p1_ = buf_) + fread(buf_, 1, 1 << 21, stdin), p1_ == p2_) \
       ? EOF                                                                 \
       : *p1_++)
// Read the next decimal integer from stdin.
// Any '-' seen while skipping non-digit characters makes the result negative.
// Returns 0 if EOF is reached before a digit appears. The byte is kept in an
// int so that EOF stays distinguishable even where plain char is unsigned.
inline int Read() {
  int s = 0, f = 1;
  int x = ch();
  for (; x < '0' || x > '9'; x = ch()) {
    if (x == EOF) return 0;
    if (x == '-') f = -1;
  }
  for (; x >= '0' && x <= '9'; x = ch()) s = s * 10 + (x & 15);
  return f == 1 ? s : -s;
}
// Output buffer; pos__ is the index of the last byte written (-1 = empty).
char buf__[1 << 21];
int pos__ = -1;
// Write the buffered bytes to stdout and empty the buffer.
inline void flush() { fwrite(buf__, 1, pos__ + 1, stdout), pos__ = -1; }
// Append one byte, flushing first if the buffer is full.
inline void pc(char x) {
  if (pos__ == (1 << 21) - 1) flush();
  buf__[++pos__] = x;
}
// Append the decimal form of x. The magnitude is taken in unsigned
// arithmetic so that INT_MIN does not overflow when negated.
inline void out(int x) {
  char k[12];
  int pos = 0;
  unsigned int u = static_cast<unsigned int>(x);
  if (x < 0) {
    pc('-');
    u = 0u - u;
  }
  do {
    k[pos++] = static_cast<char>('0' + u % 10);
    u /= 10;
  } while (u != 0);
  while (pos > 0) pc(k[--pos]);
}
// Append every character of x (taken by reference: no copy).
inline void out(const std::string& x) {
  for (char c : x) pc(c);
}
}  // namespace IO
using namespace IO;
// Adjacency lists stored as a forward star: first[x] is the newest edge
// leaving x, nxt[e] the previous one, to[e] its endpoint. Edge ids start at 1,
// and id 0 terminates a list.
int first[600005], nxt[600005], to[600005], tot;
// Push the directed edge x -> y onto the front of x's list.
void Add(int x, int y) {
    const int id = ++tot;
    to[id] = y;
    nxt[id] = first[x];
    first[x] = id;
}
// Per-node arrays are indexed by node id 1..n; index 0 is the "no node"
// sentinel (depth 0).
int n, m;               // node count, query count
int a[300005];          // letter of each node, 'a'..'z' stored as 1..26
int dep[300005];        // depth; the DFS root has depth 1
int dfn[900005];        // Euler tour dfn[1..ind]; a node is re-listed after each child
int ind;                // current length of the Euler tour
int minn[900005][22];   // minn[j][i]: shallowest node among dfn[j .. j + 2^i - 1]
int lg[900005];         // lg[i] = floor(log2(i)), lg[0] = -1
int st[300005];         // st[u]: first position of u in the Euler tour
int mdep[300005];       // deepest depth found in u's subtree
int son[300005];        // long-chain child (largest mdep), 0 for a leaf
int top[300005];        // head of the long chain containing u
int father[300005];     // parent, 0 for the root
int len[300005];        // len[h]: number of nodes on the chain headed by h
int ys[300005];         // ys[k]: index of the highest set bit of k
int f[300005][20];      // f[u][j]: 2^j-th ancestor of u, 0 past the root
char ch[300005];        // unused
std::vector<int> UP[300005], DOWN[300005];  // per-chain-head ladder tables
// Binary exponentiation: returns a^b mod Mod (1 when b == 0).
// main() uses it with b = Mod - 2 to obtain a modular inverse for a prime Mod.
inline int qpow(int a, int b, int Mod) {
    long long acc = 1;
    long long sq = a;
    for (; b != 0; b >>= 1) {
        if (b & 1) acc = acc * sq % Mod;
        sq = sq * sq % Mod;
    }
    return static_cast<int>(acc);
}
il void dfs1(int u, int fa) {
    father[u] = f[u][0] = fa;
    for(int i = 1; i <= 19; i++)  f[u][i] = f[f[u][i - 1]][i - 1];
    dfn[++ind] = u; mdep[u] = dep[u] = dep[fa] + 1;
    st[u] = ind;
    for(re int e = first[u]; e; e = nxt[e]) {
        int v = to[e];
        if(v == fa)  continue;
        dfs1(v, u); dfn[++ind] = u;
        mdep[u] = max(mdep[u], mdep[v]);
        if(mdep[v] > mdep[son[u]])  son[u] = v;
    }
}
il void dfs2(int u, int tp) {
    ++len[tp];
    top[u] = tp;
    if(son[u])  dfs2(son[u], tp);
    for(re int e = first[u]; e; e = nxt[e])
        if(!top[to[e]])  dfs2(to[e], to[e]);
}
il int my_min(int a, int b) {
    return (dep[a] < dep[b]) ? a : b;
}
il void prework() {
    for(re int i = 1; i <= ind; i++)  minn[i][0] = dfn[i];
    for(re int i = 1; (1 << i) <= ind; i++)
        for(re int j = 1; (1 << i) + j - 1 <= ind; j++)
            minn[j][i] = my_min(minn[j][i - 1], minn[j + (1 << (i - 1))][i - 1]);
}
il int getlca(int x, int y) {
    if(st[x] > st[y])  swap(x, y);
    int k = lg[st[y] - st[x] + 1];
    return my_min(minn[st[x]][k], minn[st[y] - (1 << k) + 1][k]);
}
il int getanc(int x, int k) {
    if(k >= dep[x])  return 0;
    if(k == 0)  return x;
    x = f[x][ys[k]]; k -= (1 << ys[k]);
    if(!k)  return x;
    int Tp = top[x];
    if(dep[Tp] == dep[x] - k)  return Tp;
    if(dep[Tp] > dep[x] - k)  return UP[Tp][dep[Tp] - dep[x] + k - 1];
    else  return DOWN[Tp][dep[x] - k - dep[Tp] - 1];
}
struct node {
    int base, Mod, zHsh[300005], fHsh[300005], pr[300005], ny[300005];
    il void prework() {
        pr[1] = 1;
        for(re int i = 2; i <= n; i++)  pr[i] = 1ll * pr[i - 1] * base % Mod;
        int NY = qpow(base, Mod - 2, Mod);
        ny[1] = 1;
        for(re int i = 2; i <= n; i++)  ny[i] = 1ll * ny[i - 1] * NY % Mod;
        fHsh[1] = a[1];
    }
    il void getHsh(int u, int fa) {
        zHsh[u] = (1ll * zHsh[fa] * base + a[u]) % Mod;
        for(re int e = first[u]; e; e = nxt[e]) {
            int v = to[e]; if(v == fa)  continue;
            fHsh[v] = (fHsh[u] + 1ll * a[v] * pr[dep[v]]) % Mod;
            getHsh(v, u);
        }
    }
    il int getroad(int x, int y) {
        int Lca = getlca(x, y);
        int Hsh1 = 1ll * dec(fHsh[x], fHsh[father[Lca]]) * ny[dep[Lca]] % Mod;
        if(y == Lca)  return Hsh1;
        int Hsh2 = (zHsh[y] - 1ll * zHsh[Lca] * pr[dep[y] - dep[Lca] + 1]) % Mod;
        if(Hsh2 < 0)  Hsh2 += Mod;
        return (1ll * Hsh1 * pr[dep[y] - dep[Lca] + 1] + Hsh2) % Mod;
    }
}T[2];
// Length of the longest common prefix of the letter strings on the paths
// l1 -> r1 and l2 -> r2, found by binary search on the prefix length.
// NOTE(review): only the T[0] hash is compared, so a hash collision would be
// accepted as a match.
int solve(int l1, int r1, int l2, int r2) {
    const int lcaA = getlca(l1, r1);
    const int lcaB = getlca(l2, r2);
    // Nodes on the upward part (start .. lca, inclusive) and total path length.
    const int upA = dep[l1] - dep[lcaA] + 1;
    const int upB = dep[l2] - dep[lcaB] + 1;
    const int lenA = upA + dep[r1] - dep[lcaA];
    const int lenB = upB + dep[r2] - dep[lcaB];

    int lo = 1;
    int hi = lenA < lenB ? lenA : lenB;
    int best = 0;
    while (lo <= hi) {
        const int mid = (lo + hi) / 2;
        // mid-th node of each path: an ancestor of the start on the upward
        // part, otherwise counted back from the end node.
        const int endA = mid <= upA ? getanc(l1, mid - 1) : getanc(r1, lenA - mid);
        const int endB = mid <= upB ? getanc(l2, mid - 1) : getanc(r2, lenB - mid);
        if (T[0].getroad(l1, endA) != T[0].getroad(l2, endB)) {
            hi = mid - 1;
        } else {
            best = mid;
            lo = mid + 1;
        }
    }
    return best;
}
signed main() {
//    freopen("1.in", "r", stdin);
//    freopen("1.ans", "w", stdout);
    lg[0] = -1;
    for(re int i = 1; i <= 900000; i++)  lg[i] = lg[i >> 1] + 1;
    n = Read();
    for(re int i = 1; i <= n; i++) {
        char wsssstc = ch();
        while(!isalpha(wsssstc))  wsssstc = ch();
        a[i] = wsssstc - 'a' + 1;
    }
    for(re int i = 1, x, y; i < n; i++) {
        x = Read(), y = Read();
        Add(x, y); Add(y, x);
    }
    dfs1(1, 0); dfs2(1, 1); prework();
    T[0].base = 133, T[0].Mod = 1000000007; T[0].prework();
    //T[1].base = 233, T[1].Mod = 998244353; T[1].prework();
    T[0].getHsh(1, 0); //T[1].getHsh(1, 0);
    //T[0].Print(); T[1].Print();
    for(re int i = 1; i <= n; i++) {
        if(top[i] != i)  continue;
        int l = 0, x = i;
        UP[i].resize(len[i]); DOWN[i].resize(len[i]);
        while(l < len[i] && x)
            x = f[x][0], UP[i][l] = x, ++l;
        l = 0, x = i;
        while(l < len[i] && x)
            x = son[x], DOWN[i][l] = x, ++l;
    }
    for(re int i = 1; i <= n; i++)
        for(int j = 20; j >= 0; j--)
            if(i & (1 << j))   {ys[i] = j; break;}
    m = Read();
    for(re int i = 1, l1, r1, l2, r2; i <= m; i++) {
        l1 = Read(), r1 = Read(), l2 = Read(), r2 = Read();
        out(solve(l1, r1, l2, r2)); pc('\n');
    }
    flush();
    return 0;
}
posted @ 2020-10-29 13:56  verjun  阅读(91)  评论(0编辑  收藏  举报