奇怪的主席树题 题解

前言

前置知识:可持久化线段树,字符串hash,线段树二分。

Description

给你一棵 \(n\) 个点以 \(1\) 为根节点的树,每一个节点都代表一个 \(m\) 个字符的字符串。

每一个节点都与它的父节点有一个字符只差,具体的 \(S_i\) ,代表 \(i\) 节点的字符串。设 \(u\)\(v\) 的父节点,有 \(pos_v\)\(c_v\) ,代表:

\(S_{v,j}=S_{u,j}(1\le j \le m且j\not= pos_v)\)

\(S_{v,pos_v}=c_v\)

要求你对这 \(n\) 个字符串按字典序排序后,输出排序后的编号。

输入格式

第一行输入 \(n,m\),如题意。

第二行输入 \(S_1\)

接下来 \(n-1\) 行,第 \(i\) 行输入 \(p_{i+1},pos_{i+1},c_{i+1}\)。其中 \(p_{i+1}\) 为节点 \(i+1\) 的父亲。

输出格式

\(n\) 个数,为按照字典序排序后的编号。

数据范围

\[对于所有数据\quad 1\le n,m\le 10^5,pos_i\le m,p_i\le i\\ 对于30\%的数据\quad 1\le n,m\le 5000\]

Solution

对于 \(30\%\) 的数据,我们可以求暴力 \(O(nm)\) 的求出每个字符串的哈希值,对于比较两个字符串大小,我们可以二分第一个位置使得这两个串的哈希值不同,然后比较这两个位置大小即可,这样排序的复杂度是 \(O(n\log_2n \log_2m)\),总复杂度为 \(O(n\log_2n \log_2m+nm)\)

对于 \(100\%\) 的数据,我们考虑优化 \(30\%\) 中求字符串哈希的过程。

我们可以通过一个可持久化线段树来维护每个字符的hash值,这个对于学过主席树的人来说不难实现, 每次修改把父亲的根节点复制成一个新的节点, 递归往下更新。这样每次查询区间哈希值比较的复杂度会变成 \(O(n\log_2n \log_2^2m)\) ,考虑这么将一个 \(\log_2m\) 优化掉,不难想到可以用线段树二分来比较,总复杂度 \(O(n\log_2n \log_2m)\)

CODE

#include<bits/stdc++.h>
using namespace std;
typedef unsigned long long ull;
const ull base=31;
const int N=1e5+10;
int n,m,x,y,rt[N],ls[N*40],rs[N*40],len[N*40],ans[N],tot;
ull pw[N],sum[N*40];
char s[N],ch[5];
void build(int &rt,int l,int r){
	rt=++tot;len[rt]=r-l+1;
	if(l==r){sum[rt]=s[l];return;}
	int mid=(l+r)>>1;
	build(ls[rt],l,mid);build(rs[rt],mid+1,r);
	sum[rt]=sum[ls[rt]]+sum[rs[rt]]*pw[len[ls[rt]]];
}
void update(int pre,int &rt,int l,int r,int x,int y){
	rt=++tot;ls[rt]=ls[pre];rs[rt]=rs[pre];
	len[rt]=len[pre];sum[rt]=sum[pre];
	if(l==r){sum[rt]=y;return;}
	int mid=(l+r)>>1;
	if(x<=mid)update(ls[pre],ls[rt],l,mid,x,y);
	else update(rs[pre],rs[rt],mid+1,r,x,y);
	sum[rt]=sum[ls[rt]]+sum[rs[rt]]*pw[len[ls[rt]]];
}
bool solve(int x,int y,int l,int r){
	if(l==r)return sum[x]<sum[y];
	int mid=(l+r)>>1;
	if(sum[ls[x]]==sum[ls[y]])return solve(rs[x],rs[y],mid+1,r);
	return solve(ls[x],ls[y],l,mid);
} 
inline bool cmp(int x,int y){
	if(sum[rt[x]]==sum[rt[y]])return x<y;
	return solve(rt[x],rt[y],1,m);
}
int main(){
	freopen("z.in","r",stdin);
	freopen("z.out","w",stdout);
	scanf("%d%d",&n,&m);pw[0]=1;
	for(int i=1;i<=m;i++)pw[i]=pw[i-1]*base;
	for(int i=1;i<=n;i++)ans[i]=i;
	scanf("%s",s+1);build(rt[1],1,m);
	for(int i=2;i<=n;i++){
		scanf("%d%d%s",&x,&y,ch+1);
		update(rt[x],rt[i],1,m,y,ch[1]);
	}
	sort(ans+1,ans+1+n,cmp);
	for(int i=1;i<=n;i++){
		printf("%d",ans[i]);
		if(i!=n)printf(" ");
	}
	return 0;
}
posted @ 2022-11-22 23:29  SZBR_yzh  阅读(16)  评论(0编辑  收藏  举报