题解 [CF504E] Misha and LCP on Tree
终于有一个可以二分+hash艹的题了?
哦三个 log 过不去呀
那我来口胡一个大常数 \(O(n\log^2 n)\) 做法:
查询两个串的时候在两棵 LCT 上将两个串分别 split 出来
在其中一个串上做平衡树上二分,另一个串用 kth+splay 协助完成二分
大概比三个 log 慢吧
- 关于树上路径 hash 值:利用 hash 值的可减性,可以预处理出每个点到根的 hash 值后搭配 lca 和 k 级祖先做到一个 log
于是就可以两个 log 了,还是过不去
于是继续优化:发现瓶颈在于二分和 k 级祖先
于是可以长链剖分优化求 k 级祖先,这样就是一个 log 的了
然后这题卡常,还是过不去
那么 RMQ 优化求 lca 就可以卡过了
复杂度 \(O((n+m)\log n)\)
还有一个思路不同的做法:
每个串树剖后形成 log 个串
可以从前往后在每个串上走,走到第一个不一样的再在这个区间内二分
这样看起来仍然需要求 k 级祖先
但是注意树剖后每条重链的 dfs 序是连续的,所以在这条重链上的 k 级祖先可以 \(O(1)\) 求
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 300010
#define fir first
#define sec second
#define ll long long
//#define int long long
int n, m;
char s[N];
pair<int, int> st[25][N<<1];
ll h[N], rh[N], pw[N], inv[N];
const ll base=13131, mod=1206927149;
int head[N], fa[23][N], dep[N], mdep[N], top[N], lg[N<<1], *up[N], *down[N], mson[N], pos[N], ecnt, tot;
struct edge{int to, next;}e[N<<1];
inline void add(int s, int t) {e[++ecnt]={t, head[s]}; head[s]=ecnt;}
inline ll qpow(ll a, ll b) {ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}
void dfs1(int u, int pa) {
mdep[u]=dep[u];
h[u]=(h[pa]+s[u]*pw[dep[u]-1])%mod;
rh[u]=(rh[pa]*base+s[u])%mod;
st[0][pos[u]=++tot]={dep[u], u};
for (int i=1; i<23; ++i)
if (dep[u]>=1<<i) fa[i][u]=fa[i-1][fa[i-1][u]];
else break;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v==pa) continue;
fa[0][v]=u;
dep[v]=dep[u]+1;
dfs1(v, u);
if (mdep[v]>mdep[u]) mdep[u]=mdep[v], mson[u]=v;
st[0][++tot]={dep[u], u};
}
}
void dfs2(int u, int pa, int t) {
top[u]=t;
if (u==t) {
up[u]=new int[mdep[u]-dep[u]+5];
down[u]=new int[mdep[u]-dep[u]+5];
for (int pos=0,now=u; pos<=mdep[u]-dep[u]; now=fa[0][now]) up[u][pos++]=now;
for (int pos=0,now=u; pos<=mdep[u]-dep[u]; now=mson[now]) down[u][pos++]=now;
}
if (!mson[u]) return ;
dfs2(mson[u], u, t);
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v==pa||v==mson[u]) continue;
dfs2(v, u, v);
}
}
// int lca(int a, int b) {
// if (dep[a]<dep[b]) swap(a, b);
// while (dep[a]>dep[b]) a=fa[lg[dep[a]-dep[b]]-1][a];
// if (a==b) return a;
// for (int i=lg[dep[a]]-1; ~i; --i)
// if (fa[i][a]!=fa[i][b])
// a=fa[i][a], b=fa[i][b];
// return fa[0][a];
// }
pair<int, int> qmax(int l, int r) {
// cout<<"qmax: "<<l<<' '<<r<<endl;
int t=lg[r-l+1]-1;
return st[t][l].fir<=st[t][r-(1<<t)+1].fir?st[t][l]:st[t][r-(1<<t)+1];
}
int lca(int a, int b) {return qmax(min(pos[a], pos[b]), max(pos[a], pos[b])).sec;}
int anc(int u, int k) {
// cout<<"anc: "<<u<<' '<<k<<endl;
if (!k) return u;
u=fa[31-__builtin_clz(k)][u];
k^=1<<(31-__builtin_clz(k));
if (!u) return 0;
int dis=dep[u]-dep[top[u]];
if (k<=dis) return down[top[u]][dis-k];
else return up[top[u]][k-dis];
}
ll qhash(int u, int v, int t, int len) {
if (!len) return 0;
ll ans;
int len1=dep[u]-dep[t]+1, len2=dep[v]-dep[t];
if (len<=len1) ans=(rh[u]-rh[anc(u, len)]*pw[len])%mod;
else {
ans=(rh[u]-rh[fa[0][t]]*pw[len1])%mod;
int tem=anc(v, len2-(len-len1));
ans=(ans+(h[tem]-h[t])*inv[dep[t]]%mod*pw[len1])%mod;
}
return (ans%mod+mod)%mod;
}
signed main()
{
scanf("%d%s", &n, s+1);
memset(head, -1, sizeof(head));
for (int i=1,u,v; i<n; ++i) {
scanf("%d%d", &u, &v);
add(u, v); add(v, u);
}
pw[0]=inv[0]=1; pw[1]=base; inv[1]=qpow(base, mod-2);
for (int i=2; i<=n; ++i) pw[i]=pw[i-1]*base%mod;
for (int i=2; i<=n; ++i) inv[i]=inv[i-1]*inv[1]%mod;
dep[1]=1; dfs1(1, 0); dfs2(1, 0, 1);
for (int i=1; i<=tot; ++i) lg[i]=lg[i-1]+(1<<lg[i-1]==i);
int t=lg[tot]-1;
for (int i=1; i<=t; ++i)
for (int j=1,len=1<<i-1; j+(1<<i)-1<=tot; ++j)
st[i][j]=st[i-1][j].fir<=st[i-1][j+len].fir?st[i-1][j]:st[i-1][j+len];
// cout<<"st0: "; for (int i=1; i<=tot; ++i) cout<<"("<<st[0][i].fir<<','<<st[0][i].sec<<") "; cout<<endl;
scanf("%d", &m);
for (int i=1,a,b,c,d; i<=m; ++i) {
scanf("%d%d%d%d", &a, &b, &c, &d);
int t1=lca(a, b), t2=lca(c, d);
int dis1=dep[a]+dep[b]-2*dep[t1], dis2=dep[c]+dep[d]-2*dep[t2];
// cout<<"dis: "<<dis1<<' '<<dis2<<endl;
int l=0, r=min(dis1, dis2)+1, mid;
while (l<=r) {
mid=(l+r)>>1;
// cout<<"mid: "<<mid<<endl;
// cout<<"qhash: "<<qhash(a, b, t1, mid)<<' '<<qhash(c, d, t2, mid)<<endl;
if (qhash(a, b, t1, mid)==qhash(c, d, t2, mid)) l=mid+1;
else r=mid-1;
}
printf("%d\n", l-1);
}
return 0;
}