Loading

题解 Luogu P6816 [PA2009] Quasi-template

Link

题意

给定一个小写字母串 \(s\),求:

  • 有多少字符串 \(t\) 可以超出头尾地,可重复地覆盖 \(s\)
  • 在上面的条件下,最短的 \(t\);如果有多个,输出字典序最小的。

\(|s| \leq 2 \times 10^5\)

题解

考虑一个 \(t\) 会怎样在 \(s\) 中出现。首先,它会在中间完整地出现若干次,记第一次、最后一次结束的位置分别为 \(L,R\),那么 \(t\) 的某个后缀(可以为空)还会在 \(L\) 的左侧出现,且 \(t\) 的某个前缀(可以为空)还会在 \(R\) 的右侧出现。

考虑中间完整出现的部分。我们记录下 \(\operatorname{endpos}(t)\),容易发现,仅当 \(\operatorname{endpos}(t)\) 中相邻元素的差值不超过 \(t\) 的长度时,\(t\) 可能是合法的。因为我们利用 \(\operatorname{endpos}\) 集合来判断合法,不难想到建立 SAM,对于每个等价类考虑哪些字符串可能合法。经过刚才的讨论,我们得到了等价类中合法的字符串长度的一个下界。

接下来考虑 \(L\) 左侧的部分。我们发现 \(t\)\(L\) 处完整地出现了一次,且 \(t\) 的某个后缀 \(t'\) 还会在 \(L\) 左侧出现一次,于是 \(t'\) 是前缀 \(L\) 的前缀,也是前缀 \(L\) 的后缀,因此 \(t'\) 是前缀 \(L\) 的一个 border。记 \(L\) 的最长 border 长度为 \(bp_L\),我们发现 \(t\) 合法当且仅当 \(|t| \geq L - bp_L\),否则 \(t\) 无法覆盖 \(L\) 前面的部分。

然后考虑 \(R\) 右侧的部分。我们发现因为维护的是 \(\operatorname{endpos}\) 集合,因此这一部分和 \(L\) 左侧的处理并不完全对称。

image

考虑图中的串 \(t\)。我们可以分析出 \(pre(t)\) 是后缀 \(i\) 的 border,从而得到串 \([i,R]\) 合法需要 \(R \geq n - bs_{i}\),其中 \(bs_i\) 为后缀 \(i\) 的最长 border 长度。

但是,看上去它存在反例:\(bs_i\) 可能比 \(R-i+1\) 更大。

image

根据经典结论,此时 border 长度大于字符串长度一半,则后缀 \(i\) 存在周期,故而 \(R\) 不是该等价类中最后一个 endpos,矛盾。于是这样的串根本不会被计算到。

现在字符串长度存在三个下界,分别来源于其等价类,相邻 endpos 的差值,以及 \(L\) 左侧部分的限制。同时,字符串长度存在上界,来源于其等价类。设下界、上界分别为 \(l,r\),如果 \(l > r\),则该等价类无解。

接下来考虑 \(R\) 右侧的限制,即我们需要查询满足 \(i \in [R-r+1,R-l+1]\)\(i\) 有多少满足 \(n-bs_i \leq R\),这一部分可以主席树完成。

于是第一问已经解决了:建立 SAM 并线段树合并求出每个等价类 endpos 集合,正串反串分别 KMP 求出前缀后缀的最长 border,主席树维护 \(n - bs_i\),最后区间查询求出答案。

对于第二问,求出最小长度是简单的:要求区间中最后一个小于等于 \(R\) 的位置,可以主席树维护序列上 \([1,i]\) 中最后一个小于等于 \(R\) 的位置。区间查询时只需要找到右端点对应的主席树查询,看这个位置是否大于等于左端点,即可。

对于多个长度相同的串,考虑记录下它们的起始位置,利用 SA 或二分 + Hash 来排序,取排名最小的串即可。

# include <bits/stdc++.h>

const int N=400010,INF=0x3f3f3f3f;
const int MN=N*20;

inline int read(void){
	int res,f=1;
	char c;
	while((c=getchar())<'0'||c>'9')
		if(c=='-')f=-1;
	res=c-48;
	while((c=getchar())>='0'&&c<='9')
		res=res*10+c-48;
	return res*f;
}

char s[N],rs[N];
int n;
int rt[N];

namespace sgt{
	struct Node{
		int lc,rc,mg,lp,rp;
		inline void copy(const Node &rhs){
			mg=rhs.mg,lp=rhs.lp,rp=rhs.rp;
			return;
		}
	}tr[MN];
	int cnt;
	inline void pushup(Node &cur,const Node &ls,const Node &rs){
		if(!ls.lp) return cur.copy(rs),void();
		if(!rs.rp) return cur.copy(ls),void();
		cur.lp=ls.lp,cur.rp=rs.rp;
		cur.mg=std::max(std::max(ls.mg,rs.mg),rs.lp-ls.rp);
		return;
	}
	inline int& lc(int x){
		return tr[x].lc;
	}
	inline int& rc(int x){
		return tr[x].rc;
	}
	void ins(int &k,int l,int r,int x){
		if(!k) k=++cnt;
		if(l==r) return tr[k].lp=tr[k].rp=x,void();
		int mid=(l+r)>>1;
		if(x<=mid) ins(lc(k),l,mid,x);
		else ins(rc(k),mid+1,r,x);
		pushup(tr[k],tr[lc(k)],tr[rc(k)]);
		return;
	}
	void merge(int &k,int x,int y,int l,int r){
		if(!x||!y) return k=x|y,void();
		k=++cnt;
		if(l==r) return assert(false),void();
		int mid=(l+r)>>1;
		merge(lc(k),lc(x),lc(y),l,mid),merge(rc(k),rc(x),rc(y),mid+1,r);
		pushup(tr[k],tr[lc(k)],tr[rc(k)]);
		return;
	}
	void debug(int k,int l,int r){
		if(!k) return;
		if(l==r){
			printf("%d ",l);
			return;
		}
		int mid=(l+r)>>1;
		debug(lc(k),l,mid),debug(rc(k),mid+1,r);
		return;
	}
}


namespace kmp{
	int bor[N],rbor[N];
	inline void getb(int *b,char *str){
		int mlen=0;
		for(int i=2;i<=n;++i){
			while(mlen&&str[mlen+1]!=str[i]) mlen=b[mlen];
			if(str[mlen+1]==str[i]) ++mlen;
			b[i]=mlen;
		}
		return;
	}
	inline void init(void){
		getb(bor,s),getb(rbor,rs);
		std::reverse(rbor+1,rbor+1+n);
		return;
	}
}

using namespace kmp;

namespace ssgt{
	int rt[N];
	int cnt;
	struct Node{
		int sum,lc,rc,mpos;
	}tr[MN];
	inline int& lc(int x){
		return tr[x].lc;
	}
	inline int& rc(int x){
		return tr[x].rc;
	}
	void change(int &k,int lst,int l,int r,int x,int pos){
		k=++cnt,tr[k]=tr[lst],++tr[k].sum,tr[k].mpos=pos;
		if(l==r) return;
		int mid=(l+r)>>1;
		if(x<=mid) change(lc(k),lc(lst),l,mid,x,pos);
		else change(rc(k),rc(lst),mid+1,r,x,pos);
		return;
	}
	int lastpos(int &k,int l,int r,int L,int R){
		if(!k) return 0;
		if(L<=l&&r<=R) return tr[k].mpos;
		int mid=(l+r)>>1,mx=0;
		if(L<=mid) mx=lastpos(lc(k),l,mid,L,R);
		if(mid<R) mx=std::max(mx,lastpos(rc(k),mid+1,r,L,R));
		return mx;
	}
	int qsum(int &rt,int &lt,int l,int r,int L,int R){
		if(L<=l&&r<=R) return tr[rt].sum-tr[lt].sum;
		int mid=(l+r)>>1,res=0;
		if(L<=mid) res+=qsum(lc(rt),lc(lt),l,mid,L,R);
		if(mid<R) res+=qsum(rc(rt),rc(lt),mid+1,r,L,R);
		return res; 
	}
	inline void init(void){
		for(int i=1;i<=n;++i) change(rt[i],rt[i-1],1,n,n-rbor[i],i);
		return;
	}
}

namespace sa{
	int num,fir[N],sec[N],t[N],cnt,sa[N],rank[N];
	inline void build_sa(void){
		num=128;
		for(int i=1;i<=n;++i) ++t[fir[i]=s[i]];
		for(int i=1;i<=num;++i) t[i]+=t[i-1];
		for(int i=n;i;--i) sa[t[fir[i]]--]=i;
		for(int k=1;k<=n;k<<=1){
			cnt=0;
			for(int i=n-k+1;i<=n;++i) sec[++cnt]=i;
			for(int i=1;i<=n;++i) if(sa[i]>k) sec[++cnt]=sa[i]-k;
			std::fill(t+1,t+1+num,0);
			for(int i=1;i<=n;++i) ++t[fir[i]];
			for(int i=1;i<=num;++i) t[i]+=t[i-1];
			for(int i=n;i;--i) sa[t[fir[sec[i]]]--]=sec[i];
			std::swap(sec,fir),cnt=1,fir[sa[1]]=1;
			for(int i=2;i<=n;++i) fir[sa[i]]=((sec[sa[i-1]]==sec[sa[i]]&&sec[sa[i-1]+k]==sec[sa[i]+k])?cnt:++cnt);
			if(cnt==n) break;
			num=cnt;
		}
		for(int i=1;i<=n;++i) rank[sa[i]]=i;
		return;
	}
}

namespace sam{
	struct Node{
		int ch[26],link,len;
	}s[N];
	int cnt=1,lst=1;
	inline void extend(int c){
		int cur=++cnt,p=lst,q,clone;
		s[cur].len=s[lst].len+1,lst=cur;
		while(p&&!s[p].ch[c]) s[p].ch[c]=cur,p=s[p].link;
		if(!p) return s[cur].link=1,void();
		q=s[p].ch[c];
		if(s[p].len+1==s[q].len) return s[cur].link=q,void();
		clone=++cnt,s[clone]=s[q],s[clone].len=s[p].len+1;
		s[q].link=s[cur].link=clone;
		while(p&&s[p].ch[c]==q) s[p].ch[c]=clone,p=s[p].link;
		return;
	}
	int p[N],sum[N];
	inline void init(void){
		for(int i=1;i<=n;++i){
			extend(::s[i]-'a'),sgt::ins(rt[lst],1,n,i);
		}
		for(int i=1;i<=cnt;++i) ++sum[s[i].len];
		for(int i=1;i<=n;++i) sum[i]+=sum[i-1];
		for(int i=1;i<=cnt;++i) p[sum[s[i].len]--]=i;
		for(int i=cnt;i;--i) if(p[i]!=1) sgt::merge(rt[s[p[i]].link],rt[s[p[i]].link],rt[p[i]],1,n);
		return;
	}
	inline void solve(void){
		init();
		long long ans=0;
		int minlen=INF;
		std::vector <int> sp;
		for(int i=2;i<=cnt;++i){
			auto tnode=sgt::tr[rt[i]];
			int L=std::max(tnode.mg,s[s[i].link].len+1),R=s[i].len,lp=tnode.lp,rp=tnode.rp,res,pos;
			L=std::max(L,lp-bor[lp]);
			if(L>R) continue;
			ans+=(res=ssgt::qsum(ssgt::rt[rp-L+1],ssgt::rt[rp-R],1,n,1,rp)),
			pos=ssgt::lastpos(ssgt::rt[rp-L+1],1,n,1,rp);
			if(!res) continue;
			if(rp-pos+1<minlen) sp.clear(),sp.push_back(pos),minlen=rp-pos+1;
			else if(rp-pos+1==minlen) sp.push_back(pos);
		}
		assert(sp.size());
		printf("%lld\n",ans);
		int rpos=sp[0];
		for(auto v:sp) if(sa::rank[v]<sa::rank[rpos]) rpos=v;
		for(int i=rpos;i<=rpos+minlen-1;++i) putchar(::s[i]);
		return;
	}
}


int main(void){
	scanf("%s",s+1),n=strlen(s+1);
	for(int i=1;i<=n;++i) rs[i]=s[n-i+1];
	kmp::init(),ssgt::init(),sa::build_sa();
	sam::solve();

	return 0;
}

posted @ 2023-08-01 10:26  Meatherm  阅读(45)  评论(2编辑  收藏  举报