【湖北省选集训 2023】日记 题解

一个更直接的想法。

首先考虑如何去重,这里的手法基本上都一样:钦定前缀是极长的,那么设你拿长度为 \(i\) 的前缀和长度为 \(j\) 的后缀去拼成一个串 \(P=S_{[1,i]}+S_{[n-j+1,n]}\),我们只需要统计 \(S_{i+1}\neq S_{n-j+1}\)\(P\) 就好了。

接下来考虑包含 \(T\) 的条件,最简单的情况是长度为 \(i\) 的前缀或长度为 \(j\) 的后缀已经包含了 \(T\)。找到 \(T\)\(S\) 中的第一次出现和最后一次出现,我们容易计算出这一部分对答案的贡献,注意这里需要开个桶处理一下 \(S_{i+1}\neq S_{n-j+1}\) 的限制。

接下来考虑 \(T\) 不在前缀后缀中出现但在 \(P\) 中出现了的情况。先考虑怎么得到一个 \(O(n^2)\) 的做法:我们对于每一个前缀,求出最长的 \(T\) 的前缀满足其是该前缀的后缀,即 KMP 算法中维护的最大匹配长度。设其为 \(x\),则我们知道跟这个前缀的某个后缀匹配的所有的 \(T\) 的前缀是 \(\{x,nxt_x,nxt_{nxt_x}\dots\}\),即 \(x\)\(T\)\(nxt\) 树上到根的链。同理用 \(T\) 的反串的 \(nxt\) 树可以刻画出与每个后缀的匹配关系。那么问前缀后缀拼起来包不包括 \(T\) 相当于是在问这两颗树上的到根的链中是否存在一对点满足它们的长度之和为 \(|T|\)。直接预处理复杂度就是 \(O(n^2)\) 了。

接下来依然先枚举前缀,然后考虑所有与这个前缀的后缀匹配的 \(T\) 的前缀 \(x\),你发现,由于此时 \(T_{x+1}\) 需要与你选出的后缀的第一个字符匹配,为了满足 \(S_{i+1}\neq S_{n-j+1}\),就有 \(T_{x+1}\neq S_{i+1}\)。这令我们回想起了用 KMP 偏序 Z 函数中的处理方式。你可以看作用 \(T\) 去跟 \(S\) 的每一个前缀求 LCP,也就是说,对于所有前缀 \(i\),其可能满足 \(T_{x+1}\neq S_{i+1}\)\(T_{[1,x]}=S_{[i-x+1,i]}\)\(x\) 的总个数是 \(O(n)\) 的!

我们可以直接跑 Z 函数,或者用 KMP 求 Z 的方法优化暴力跳,都可以得到每个前缀 \(i\) 对应的所有 \(x\),然后你需要在 \(T\) 反串 \(nxt\) 树中找到 \(|T|-x\) 对应的子树,接下来你需要求出这些子树并中的权值之和,权值定义为每个节点匹配上的后缀的个数。

求子树并权值和看起来甩不掉 \(\log\) 因子?我们可以离线!直接开个桶,然后计算每一个权值被多少个子树并算过贡献即可。时间复杂度 \(O(n)\)

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N=5000003,M=10000003;
char s[N],t[M];
int n,m;
int nxt[N],sz[N];
int nx[N],jump[N];
int sum,w[26];
ll res;
int hd[N],ver[M],tran[M],tot;
void add(int u,int v){tran[++tot]=hd[u];hd[u]=tot;ver[tot]=v;}
int cnt[N],num;
void dfs(int u){
	res+=num*sz[u];
	for(int i=hd[u];i;i=tran[i]){
		int v=ver[i];
		if(v<0){if(!cnt[-v]++) ++num,res+=sz[u];}
		else dfs(v);
	}
	for(int i=hd[u];i;i=tran[i]){
		int v=ver[i];
		if(v<0){if(!--cnt[-v]) --num;}
		else break;
	}
}
int main(){
	scanf("%s%s",s+1,t+1);
	n=strlen(s+1);m=strlen(t+1);
	reverse(s+1,s+n+1);reverse(t+1,t+m+1);
	int tt=min(m-1,n);
	for(int i=2,j=0;i<=tt;++i){
		while(j&&t[j+1]!=t[i]) j=nxt[j];
		if(t[j+1]==t[i]) ++j;
		nxt[i]=j;
	}
	int pos=n+1;
	for(int i=1,j=0;i<=n;++i){
		while(j&&t[j+1]!=s[i]) j=nxt[j];
		if(t[j+1]==s[i]) ++j;
		if(j==m){pos=i;break;}
		if(j) ++sz[j];
	}
	for(int i=2;i<=tt;++i) if(nxt[i]) add(nxt[i],i);
	reverse(s+1,s+n+1);reverse(t+1,t+m+1);
	for(int i=2,j=0;i<=tt;++i){
		while(j&&t[j+1]!=t[i]) j=nx[j];
		if(t[j+1]==t[i]) ++j;
		nx[i]=j;
		if(t[i+1]==t[j+1]) jump[i]=jump[j];
		else jump[i]=j;
	}
	int lim=n+1;
	s[n+1]='#';
	for(int i=1,j=0;i<=n;++i){
		while(j&&t[j+1]!=s[i]) j=nx[j];
		if(t[j+1]==s[i]) ++j;
		if(j==m){lim=i;break;}
		for(int x=j;x;)
			if(t[x+1]!=s[i+1]){
				if(m-x<=tt) add(m-x,-i);
				x=nx[x];
			}
			else x=jump[x];
	}
	for(int i=1;i<=tt;++i) if(!nxt[i]) dfs(i);
	sum=pos;
	for(int i=1;i<pos;++i) ++w[s[n-i+1]-97];
	for(int i=lim;i<=n;++i){res+=sum;if(i<n) res-=w[s[i+1]-97];}
	memset(w,0,104);
	sum=n-pos+1;
	for(int i=pos;i<=n;++i) ++w[s[n-i+1]-97];
	for(int i=0;i<=n;++i){res+=sum;if(i<n) res-=w[s[i+1]-97];}
	printf("%lld\n",res);
	return 0;
}
posted @ 2024-02-28 22:29  yyyyxh  阅读(46)  评论(0编辑  收藏  举报