题解 [CF504E] Misha and LCP on Tree

传送门

终于有一个可以二分+hash艹的题了?
哦三个 log 过不去呀
那我来口胡一个大常数 \(O(n\log^2 n)\) 做法:
查询两个串的时候在两棵 LCT 上将两个串分别 split 出来
在其中一个串上做平衡树上二分,另一个串用 kth+splay 协助完成二分
大概比三个 log 慢吧

  • 关于树上路径 hash 值:利用 hash 值的可减性,可以预处理出每个点到根的 hash 值后搭配 lca 和 k 级祖先做到一个 log

于是就可以两个 log 了,还是过不去
于是继续优化:发现瓶颈在于二分和 k 级祖先
于是可以长链剖分优化求 k 级祖先,这样就是一个 log 的了
然后这题卡常,还是过不去
那么 RMQ 优化求 lca 就可以卡过了
复杂度 \(O((n+m)\log n)\)

还有一个思路不同的做法:
每个串树剖后形成 log 个串
可以从前往后在每个串上走,走到第一个不一样的再在这个区间内二分
这样看起来仍然需要求 k 级祖先
但是注意树剖后每条重链的 dfs 序是连续的,所以在这条重链上的 k 级祖先可以 \(O(1)\)

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 300010
#define fir first
#define sec second
#define ll long long
//#define int long long

int n, m;
char s[N];
pair<int, int> st[25][N<<1];
ll h[N], rh[N], pw[N], inv[N];
const ll base=13131, mod=1206927149;
int head[N], fa[23][N], dep[N], mdep[N], top[N], lg[N<<1], *up[N], *down[N], mson[N], pos[N], ecnt, tot;
struct edge{int to, next;}e[N<<1];
inline void add(int s, int t) {e[++ecnt]={t, head[s]}; head[s]=ecnt;}
inline ll qpow(ll a, ll b) {ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}

void dfs1(int u, int pa) {
	mdep[u]=dep[u];
	h[u]=(h[pa]+s[u]*pw[dep[u]-1])%mod;
	rh[u]=(rh[pa]*base+s[u])%mod;
	st[0][pos[u]=++tot]={dep[u], u};
	for (int i=1; i<23; ++i)
		if (dep[u]>=1<<i) fa[i][u]=fa[i-1][fa[i-1][u]];
		else break;
	for (int i=head[u],v; ~i; i=e[i].next) {
		v = e[i].to;
		if (v==pa) continue;
		fa[0][v]=u;
		dep[v]=dep[u]+1;
		dfs1(v, u);
		if (mdep[v]>mdep[u]) mdep[u]=mdep[v], mson[u]=v;
		st[0][++tot]={dep[u], u};
	}
}

void dfs2(int u, int pa, int t) {
	top[u]=t;
	if (u==t) {
		up[u]=new int[mdep[u]-dep[u]+5];
		down[u]=new int[mdep[u]-dep[u]+5];
		for (int pos=0,now=u; pos<=mdep[u]-dep[u]; now=fa[0][now]) up[u][pos++]=now;
		for (int pos=0,now=u; pos<=mdep[u]-dep[u]; now=mson[now]) down[u][pos++]=now;
	}
	if (!mson[u]) return ;
	dfs2(mson[u], u, t);
	for (int i=head[u],v; ~i; i=e[i].next) {
		v = e[i].to;
		if (v==pa||v==mson[u]) continue;
		dfs2(v, u, v);
	}
}

// int lca(int a, int b) {
// 	if (dep[a]<dep[b]) swap(a, b);
// 	while (dep[a]>dep[b]) a=fa[lg[dep[a]-dep[b]]-1][a];
// 	if (a==b) return a;
// 	for (int i=lg[dep[a]]-1; ~i; --i)
// 		if (fa[i][a]!=fa[i][b])
// 			a=fa[i][a], b=fa[i][b];
// 	return fa[0][a];
// }

pair<int, int> qmax(int l, int r) {
	// cout<<"qmax: "<<l<<' '<<r<<endl;
	int t=lg[r-l+1]-1;
	return st[t][l].fir<=st[t][r-(1<<t)+1].fir?st[t][l]:st[t][r-(1<<t)+1];
}

int lca(int a, int b) {return qmax(min(pos[a], pos[b]), max(pos[a], pos[b])).sec;}

int anc(int u, int k) {
	// cout<<"anc: "<<u<<' '<<k<<endl;
	if (!k) return u;
	u=fa[31-__builtin_clz(k)][u];
	k^=1<<(31-__builtin_clz(k));
	if (!u) return 0;
	int dis=dep[u]-dep[top[u]];
	if (k<=dis) return down[top[u]][dis-k];
	else return up[top[u]][k-dis];
}

ll qhash(int u, int v, int t, int len) {
	if (!len) return 0;
	ll ans;
	int len1=dep[u]-dep[t]+1, len2=dep[v]-dep[t];
	if (len<=len1) ans=(rh[u]-rh[anc(u, len)]*pw[len])%mod;
	else {
		ans=(rh[u]-rh[fa[0][t]]*pw[len1])%mod;
		int tem=anc(v, len2-(len-len1));
		ans=(ans+(h[tem]-h[t])*inv[dep[t]]%mod*pw[len1])%mod;
	}
	return (ans%mod+mod)%mod;
}

signed main()
{
	scanf("%d%s", &n, s+1);
	memset(head, -1, sizeof(head));
	for (int i=1,u,v; i<n; ++i) {
		scanf("%d%d", &u, &v);
		add(u, v); add(v, u);
	}
	pw[0]=inv[0]=1; pw[1]=base; inv[1]=qpow(base, mod-2);
	for (int i=2; i<=n; ++i) pw[i]=pw[i-1]*base%mod;
	for (int i=2; i<=n; ++i) inv[i]=inv[i-1]*inv[1]%mod;
	dep[1]=1; dfs1(1, 0); dfs2(1, 0, 1);
	for (int i=1; i<=tot; ++i) lg[i]=lg[i-1]+(1<<lg[i-1]==i);
	int t=lg[tot]-1;
	for (int i=1; i<=t; ++i)
		for (int j=1,len=1<<i-1; j+(1<<i)-1<=tot; ++j)
			st[i][j]=st[i-1][j].fir<=st[i-1][j+len].fir?st[i-1][j]:st[i-1][j+len];
	// cout<<"st0: "; for (int i=1; i<=tot; ++i) cout<<"("<<st[0][i].fir<<','<<st[0][i].sec<<") "; cout<<endl;
	scanf("%d", &m);
	for (int i=1,a,b,c,d; i<=m; ++i) {
		scanf("%d%d%d%d", &a, &b, &c, &d);
		int t1=lca(a, b), t2=lca(c, d);
		int dis1=dep[a]+dep[b]-2*dep[t1], dis2=dep[c]+dep[d]-2*dep[t2];
		// cout<<"dis: "<<dis1<<' '<<dis2<<endl;
		int l=0, r=min(dis1, dis2)+1, mid;	
		while (l<=r) {
			mid=(l+r)>>1;
			// cout<<"mid: "<<mid<<endl;
			// cout<<"qhash: "<<qhash(a, b, t1, mid)<<' '<<qhash(c, d, t2, mid)<<endl;
			if (qhash(a, b, t1, mid)==qhash(c, d, t2, mid)) l=mid+1;
			else r=mid-1;
		}
		printf("%d\n", l-1);
	}
	
	return 0;
}
posted @ 2022-04-27 20:33  Administrator-09  阅读(1)  评论(0编辑  收藏  举报