hdu7462-字符串【SAM,二分】

正题

题目链接:https://acm.hdu.edu.cn/showproblem.php?pid=7462


题目大意

你有一个由 \(a,b\) 组成的字符串 \(s\)
\(m\) 个操作:

  • 询问有多少个本质不同的串 \(t\) 使得 \(s[l,r]\)\(t\) 的子串且两个串在 \(s\) 中的出现次数相同。
  • 询问有多少个本质不同的串 \(t\) 使得 \(t\)\(s[l,r]\) 的子串且两个串在 \(s\) 中的出现次数相同。

强制在线

\(\sum |s|\leq 5\times 10^5,\sum q\leq 5\times 10^5\)


解题思路

首先 \(t\) 肯定也是 \(s\) 的子串,所以我们考虑在SAM上解决这个问题。

先建一个SAM,我们先考虑询问1,设询问区间为 \([l,r]\) ,我们先在SAM上找到对应节点(经典方法建立SAM时记录每个后缀所在节点,然后从这个节点在parent树往上倍增跳到 \(len\) 区间为 \(r-l+1\) 的位置)。

那么此时所有出现次数相同的串 \(t=s[x,r]\ (x\leq l)\) 也在同一个节点处,而且我们可以用 \(len_x-(r-l)\) 得到个数。

此时考虑右端点往右扩展的情况,相当于往当前节点SAM字符 \(s_{r+1}\) 的方向走一步,暴力跳肯定是不对的,我们分析一下性质。

假设 \(s_{r+1}=a\) ,假设当前节点 \(x\) 既有字符 \(a\) 的边也有字符 \(b\) 的边,那么说明往下走了之后出现次数肯定会比当前串少,所以就没有往下走的必要了。也就是我们往下走的路径肯定是SAM上的一条出现次数相同的链。

我们可以先处理出每个节点在原串的出现次数,然后把所有这样的链拉出来,询问时直接二分我们能够往后走到哪个位置,顺便使用前缀和记录一下 \(len_x\) 的和即可。

对于询问2操作类似的变为往前调,记录 \(len_{fa_x}\) 的和,但是需要考虑的一点是因为当前节点的长度是一个区间 \(len_{fa_x}\leq r-l+1\leq len_x\),我们往前走到 \(x'\) 时可能存在 \(r-l\leq len_{fa_x'}\) 的情况,也就是对应节点需要往上跳,但是我们考虑如果记 \(c_x\)\(x\) 节点对应串的出现次数,那么有 \(c_{fa_{x'}}>c_{x'}\geq c_{x}\) ,所有如果往上跳了 \(s[l,r-1]\) 出现次数肯定就比 \(s[l,r]\) 多了,就没有继续的必要了,所以我们二分时还需要维护一下中间是否需要往上跳。

时间复杂度:\(O(n\log n)\)

因为是赛时代码很多东西没想明白所以程序会写的比较臃肿。


code

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<cctype>
#define ull unsigned long long
#define ll long long
using namespace std;
const int N=1e6+10;
const ull g=131;
struct node{
	int c,fa;
	ll l,r;
	ull h;
}zero;
int T,n,q,las,cnt,tot;
int fa[N],len[N],ch[N][2],ct[N],f[N][20];
int ed[N],pos[N],seg[N],p[N];
vector<node> v[N];
vector<int> G[N];
ull pw[N],has[N];
char s[N];
int read(){
	int x=0,f=1;char c=getchar();
	while(!isdigit(c)){if(c=='-')f=-f;c=getchar();}
	while(isdigit(c)){x=(x<<1)+(x<<3)+c-'0';c=getchar();}
	return x*f;
}
void ins(int c){
	int p=las;int np=las=++cnt;
	len[np]=len[p]+1;
	for(;p&&!ch[p][c];p=fa[p])ch[p][c]=np;
	if(!p)fa[np]=1;
	else{
		int q=ch[p][c];
		if(len[p]+1==len[q])fa[np]=q;
		else{
			int nq=++cnt;len[nq]=len[p]+1;
			ch[nq][0]=ch[q][0];ch[nq][1]=ch[q][1];
//			memcpy(ch[nq],ch[q],sizeof(ch[nq]));
			fa[nq]=fa[q];fa[np]=fa[q]=nq;
			for(;p&&ch[p][c]==q;p=fa[p])ch[p][c]=nq;
		}
	}
	return;
}
void dfs(int x){
	for(int i=0;i<G[x].size();i++){
		f[G[x][i]][0]=x,dfs(G[x][i]);
		ct[x]+=ct[G[x][i]];
	}
	return;
}
bool cmp(int x,int y)
{return (len[x]==len[y])?(x>y):(len[x]<len[y]);}
int main()
{
	pw[0]=1;
	for(int i=1;i<N;i++)pw[i]=pw[i-1]*g;
	scanf("%d",&T);
	while(T--){
		scanf("%s",s+1);n=strlen(s+1);
		las=cnt=1;
		for(int i=1;i<=n;i++){
			ins(s[i]-'a');
			ed[i]=las;ct[las]++;
			has[i]=has[i-1]*g+s[i]-'a';
		}
		for(int i=2;i<=cnt;i++)
			G[fa[i]].push_back(i);
		dfs(1);
		for(int j=1;j<20;j++)
			for(int i=1;i<=cnt;i++)
				f[i][j]=f[f[i][j-1]][j-1];
		for(int i=1;i<=cnt;i++)p[i]=i;
		sort(p+1,p+1+cnt,cmp);
		for(int i=1;i<=cnt;i++){
			int x=p[i],c=0;
			if(!pos[x])pos[x]=++tot,v[pos[x]].push_back(zero);
			if(ch[x][0]&&ct[x]==ct[ch[x][0]])pos[ch[x][0]]=pos[x],c=0;
			if(ch[x][1]&&ct[x]==ct[ch[x][1]])pos[ch[x][1]]=pos[x],c=1;
			node w;
			w.c=c;w.r=len[x];w.l=len[fa[x]];
			w.fa=len[fa[x]];
			v[pos[x]].push_back(w);
			seg[x]=v[pos[x]].size()-1;
		}
		for(int i=1;i<=tot;i++){
			for(int j=1;j<v[i].size();j++){
				v[i][j].h=v[i][j-1].h*g+v[i][j].c;
				v[i][j].l+=v[i][j-1].l;
				v[i][j].r+=v[i][j-1].r;
			}
		}
		scanf("%d",&q);long long lasans=0;
		while(q--){
			int op=read(),l=read(),r=read();
			l=(l+lasans-1)%n+1;r=(r+lasans-1)%n+1;
			if(op==1){
				int x=ed[r];
				for(int j=19;j>=0;j--)
					if(len[f[x][j]]>=r-l+1)x=f[x][j];
				int id=pos[x],now=seg[x];
				int L=now,R=v[id].size()-2;
				while(L<=R){
					int mid=(L+R)>>1,dl=mid-now+1;
					if(v[id][mid].h-v[id][now-1].h*pw[dl]==has[r+dl]-has[r]*pw[dl])L=mid+1;
					else R=mid-1;
				}
				int dl=L-now;
				long long ans=(v[id][L].r-v[id][now-1].r)-1ll*((r-l)+(r-l+dl))*(dl+1)/2ll;
				printf("%lld\n",ans);lasans=ans%n;
			}else{
				int x=ed[r];
				for(int j=19;j>=0;j--)
					if(len[f[x][j]]>=r-l+1)x=f[x][j];
				int id=pos[x],now=seg[x];
				int L=max(1,now-(r-l+1)),R=now-1;
				while(L<=R){
					int mid=(L+R)>>1,dl=now-mid;
					if(v[id][now-1].h-v[id][mid-1].h*pw[dl]==has[r]-has[r-dl]*pw[dl]
					&&v[id][mid].fa<(r-l+1)-dl)R=mid-1;
					else L=mid+1;
				}
				int dl=now-L;
				long long ans=1ll*((r-l+1)+(r-l+1-dl))*(dl+1)/2ll-(v[id][now].l-v[id][L-1].l);
				printf("%lld\n",ans);lasans=ans%n;
			}
		}
		
		for(int i=1;i<=cnt;i++){
			fa[i]=ch[i][0]=ch[i][1]=len[i]=ct[i]=0;
			pos[i]=seg[i]=ed[i]=p[i]=has[i]=0;
			G[i].clear();
		}
		for(int i=1;i<=tot;i++)v[i].clear();
		tot=cnt=0;
	}
	return 0;
}
posted @ 2024-08-13 10:26  QuantAsk  阅读(13)  评论(0编辑  收藏  举报