题解

SAM+01Trie树合并的模板题(两样东西我都不太会,写了我一下午,我太菜了5555……)

先考虑反着建SAM,得到的fail树就是原串的后缀树

后缀树上两个点的LCA的endpos集合的最大长度就是这两个点的LCP长度

这样我们就解决了第一个值LCP(x,y)

考虑dfs一遍后缀树,那么每到一个节点,它的子树中的所有点对的贡献中的LCP(x,y)都是固定的了

如何求一个子树中所有节点权值的两两异或最大值呢?

用01Trie树来进行贪心,先尽量走不同方向,再走相同方向

但是如果在每个节点都开一个01Trie树并把所有子节点都加进来,复杂度会爆炸

我们考虑把所有儿子的01Trie树来合并到一起

合并的时候可以不用新建节点来可持久化(所以为什么放在了可持久化作业里a。。。)

(如果建了新点就是可持久化01Trie树合并,空间复杂度为O(nlogn))

(事实上大部分题都是需要新建点的,否则会出许多奇奇怪怪的问题)

这样的复杂度分析类似于线段树合并,是O(nlogn)的

因为我们只有在两棵线段树都存在节点的时候合并两棵线段树的答案

一次合并的复杂度其实是较小的线段树的节点个数(查询的复杂度也是)

查询的时候应该在合并之前查询,否则时间复杂度会从均摊的O(nlogn)退化为O(n^2)

 

这是有的人就会问了,“一次合并的复杂度其实是较小的线段树的节点个数”,那么总的合并次数就应该是O(nlogn)的

而set启发式合并也是只合并了O(nlogn)次,那为什么set合并的复杂度为O(nlog^2n)呢?

因为线段树合并一次是均摊O(minsiz*c)的(minsiz为较小的点集的大小,c是常数)

而set合并是一个一个暴力插入的是O(minsiz*log(maxsiz))的

 

各种合并算法的时间复杂度分析

这里不得不提到一些关于合并的问题

长链剖分合并两条链时的合并次数是均摊O(n)的

因为一条长度为5的链合并到一条长度为10的链得到的是长度为10的链

也就是说每条链只会在向链顶父亲合并时会贡献O(链长)的合并次数

而总链长为O(n),时间复杂度也就是O(n*一次insert的复杂度)

 

重链剖分合并两棵子树是的合并次数是均摊O(nlogn)的

因为一棵大小为5的子树合并到一棵大小为10的子树得到的是大小为15的子树

也就是说每个top点只会在向其父亲合并时会贡献O(子树大小)的合并次数

均摊到每个点的身上,就相当于每个点都向上爬链,一直爬到根,经过的轻边次数就是它在合并中贡献的合并次数

由于一个点到根路径最多经过O(logn)条轻边,所以总合并次数就是均摊O(nlogn)

总时间复杂度就是O(nlogn*一次insert的复杂度)

其实启发式合并的原理与这个大同小异

 

 

本题代码:

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define N 200005
int fa[N],ch[N][26],mxlen[N],id[N],last,sz;
void extend(int x,int pos)
{
	int np,p,nq,q;
	p=last;np=++sz;
	mxlen[np]=mxlen[p]+1;id[np]=pos;
	for(;p!=-1&&!ch[p][x];p=fa[p])
		ch[p][x]=np;
	if(p==-1)fa[np]=0;
	else{
		q=ch[p][x];
		if(mxlen[q]==mxlen[p]+1)fa[np]=q;
		else{
			nq=++sz;
			mxlen[nq]=mxlen[p]+1;
			memcpy(ch[nq],ch[q],sizeof(ch[q]));fa[nq]=fa[q];
			for(;p!=-1&&ch[p][x]==q;p=fa[p])ch[p][x]=nq;
			fa[np]=fa[q]=nq;
		}
	}
	last=np;
}
char s[N];int w[N];
int fir[N],to[N],nxt[N],cnt;
void adde(int a,int b){to[++cnt]=b;nxt[cnt]=fir[a];fir[a]=cnt;}
struct node{
	int ch[2];
}a[N<<5];
int T[N],tot;
void insert(int &i,int x,int d)
{
	if(!i)i=++tot;
	if(d==-1)return;
	insert(a[i].ch[(x>>d)&1],x,d-1);
}
int ans;
void query(int i,int j,int sum,int d)
{
	if(d==-1){ans=max(sum,ans);return;}
	if((a[i].ch[0]&&a[j].ch[1])||(a[i].ch[1]&&a[j].ch[0])){
		if(a[i].ch[0]&&a[j].ch[1])
			query(a[i].ch[0],a[j].ch[1],sum+(1<<d),d-1);
		if(a[i].ch[1]&&a[j].ch[0])
			query(a[i].ch[1],a[j].ch[0],sum+(1<<d),d-1);
	}
	else{
		if(a[i].ch[0]&&a[j].ch[0])
			query(a[i].ch[0],a[j].ch[0],sum,d-1);
		if(a[i].ch[1]&&a[j].ch[1])
			query(a[i].ch[1],a[j].ch[1],sum,d-1);
	}
}
int merge(int i,int j)
{
	if(!i||!j)return i+j;
	a[i].ch[0]=merge(a[i].ch[0],a[j].ch[0]);
	a[i].ch[1]=merge(a[i].ch[1],a[j].ch[1]);
	return i;
}
int siz[N];
void solve(int u)
{
	if(id[u])insert(T[u],w[id[u]],16),siz[u]=1;
	for(int v,p=fir[u];p;p=nxt[p]){
		v=to[p];solve(v);
		//if(siz[u]<siz[v])swap(T[u],T[v]);
		query(T[u],T[v],mxlen[u],16);
		T[u]=merge(T[u],T[v]);siz[u]+=siz[v];
	}
}
int main()
{
	last=0;fa[0]=-1;
	int n,i;
	scanf("%d%s",&n,s+1);
	for(i=n;i>=1;i--)extend(s[i]-'a',i);
	for(i=1;i<=n;i++)scanf("%d",&w[i]);
	for(i=1;i<=sz;i++)adde(fa[i],i);
	solve(0);
	printf("%d\n",ans);
}