奇怪的主席树题 题解
前言
前置知识:可持久化线段树,字符串hash,线段树二分。
Description
给你一棵 \(n\) 个点以 \(1\) 为根节点的树,每一个节点都代表一个 \(m\) 个字符的字符串。
每一个节点都与它的父节点有一个字符只差,具体的 \(S_i\) ,代表 \(i\) 节点的字符串。设 \(u\) 为 \(v\) 的父节点,有 \(pos_v\) 和 \(c_v\) ,代表:
\(S_{v,j}=S_{u,j}(1\le j \le m且j\not= pos_v)\)
\(S_{v,pos_v}=c_v\)
要求你对这 \(n\) 个字符串按字典序排序后,输出排序后的编号。
输入格式
第一行输入 \(n,m\),如题意。
第二行输入 \(S_1\)。
接下来 \(n-1\) 行,第 \(i\) 行输入 \(p_{i+1},pos_{i+1},c_{i+1}\)。其中 \(p_{i+1}\) 为节点 \(i+1\) 的父亲。
输出格式
\(n\) 个数,为按照字典序排序后的编号。
数据范围
Solution
对于 \(30\%\) 的数据,我们可以求暴力 \(O(nm)\) 的求出每个字符串的哈希值,对于比较两个字符串大小,我们可以二分第一个位置使得这两个串的哈希值不同,然后比较这两个位置大小即可,这样排序的复杂度是 \(O(n\log_2n \log_2m)\),总复杂度为 \(O(n\log_2n \log_2m+nm)\)。
对于 \(100\%\) 的数据,我们考虑优化 \(30\%\) 中求字符串哈希的过程。
我们可以通过一个可持久化线段树来维护每个字符的hash值,这个对于学过主席树的人来说不难实现, 每次修改把父亲的根节点复制成一个新的节点, 递归往下更新。这样每次查询区间哈希值比较的复杂度会变成 \(O(n\log_2n \log_2^2m)\) ,考虑这么将一个 \(\log_2m\) 优化掉,不难想到可以用线段树二分来比较,总复杂度 \(O(n\log_2n \log_2m)\)。
CODE
#include<bits/stdc++.h>
using namespace std;
typedef unsigned long long ull;
const ull base=31;
const int N=1e5+10;
int n,m,x,y,rt[N],ls[N*40],rs[N*40],len[N*40],ans[N],tot;
ull pw[N],sum[N*40];
char s[N],ch[5];
void build(int &rt,int l,int r){
rt=++tot;len[rt]=r-l+1;
if(l==r){sum[rt]=s[l];return;}
int mid=(l+r)>>1;
build(ls[rt],l,mid);build(rs[rt],mid+1,r);
sum[rt]=sum[ls[rt]]+sum[rs[rt]]*pw[len[ls[rt]]];
}
void update(int pre,int &rt,int l,int r,int x,int y){
rt=++tot;ls[rt]=ls[pre];rs[rt]=rs[pre];
len[rt]=len[pre];sum[rt]=sum[pre];
if(l==r){sum[rt]=y;return;}
int mid=(l+r)>>1;
if(x<=mid)update(ls[pre],ls[rt],l,mid,x,y);
else update(rs[pre],rs[rt],mid+1,r,x,y);
sum[rt]=sum[ls[rt]]+sum[rs[rt]]*pw[len[ls[rt]]];
}
bool solve(int x,int y,int l,int r){
if(l==r)return sum[x]<sum[y];
int mid=(l+r)>>1;
if(sum[ls[x]]==sum[ls[y]])return solve(rs[x],rs[y],mid+1,r);
return solve(ls[x],ls[y],l,mid);
}
inline bool cmp(int x,int y){
if(sum[rt[x]]==sum[rt[y]])return x<y;
return solve(rt[x],rt[y],1,m);
}
int main(){
freopen("z.in","r",stdin);
freopen("z.out","w",stdout);
scanf("%d%d",&n,&m);pw[0]=1;
for(int i=1;i<=m;i++)pw[i]=pw[i-1]*base;
for(int i=1;i<=n;i++)ans[i]=i;
scanf("%s",s+1);build(rt[1],1,m);
for(int i=2;i<=n;i++){
scanf("%d%d%s",&x,&y,ch+1);
update(rt[x],rt[i],1,m,y,ch[1]);
}
sort(ans+1,ans+1+n,cmp);
for(int i=1;i<=n;i++){
printf("%d",ans[i]);
if(i!=n)printf(" ");
}
return 0;
}