Loading

CF653F Paper task

题意简述

求一个长度为 \(n\) 的括号串的不同合法括号串的个数。\(n\le 5\times 10^5\)

Solution

真是 SAM 白学了呢……

看到题目中要求本质不同的子串,那肯定是除了 SAM 没有人能够胜任了。然后考虑到括号串判断合法直接做比较麻烦,所以根据后缀自动机的性质,我们希望能够通过后缀和来反映这个子串是否合法。比如说,我们考虑令 (\(-1\))\(1\),则只要后缀和都非负,并且整个的和是 \(0\),则是合法的串。

那把 SAM 建立起来,考虑每一个节点。一个节点代表了一个 endpos 等价类。那我们在这些 endpos 中任意取一个,则这个节点所代表的子串就是以 endpos 为结尾,开头在这个节点所代表的 len 的区间内的所有子串。

然后我们现在已经知道需要具体计算哪些子串是合法的,那么就要用到上面提到的转化。我们现假设待判断的串以 \(ed\) 位置结尾,开头在 \([l,r]\) 区间内,\(suf\) 是原串的后缀和,则此区间中合法的子串,应当满足:

\[\forall st\in[p,ed],suf[st]\ge suf[ed+1]\\ suf[st]-suf[ed+1]=0 \]

也就是问题转化成,求区间中有多少个数 \(st\),满足 \([st,ed]\) 中的数都大于某一个数,并且 \(st\) 位置的数等于那个数。直接用 ST 表判断肯定是 TT,所以考虑一个小优化。

你发现,这东西很显然越往前越容易不合法,所以直接二分这个界限就可以了。

所以相当于我们需要求区间中的最小值和区间中等于某个数的个数。直接用 st 表维护,并用 vector 存储每个元素出现的所有下标,然后在每个 vector 内二分出位于当前区间的那段,就可以知道最终答案了。

每日调戏洛谷爬虫

Code

// Problem: CF653F Paper task
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/CF653F
// Memory Limit: 2600 MB
// Time Limit: 500000 ms

#include<bits/stdc++.h>
#define ll long long
#define inf (1<<30)
#define INF (1ll<<60)
#define pb emplace_back
#define pii pair<int,int>
#define mkp make_pair
#define fi first
#define se second
#define all(a) a.begin(),a.end()
#define siz(a) (int)a.size()
#define clr(a) memset(a,0,sizeof(a))
#define rep(i,j,k) for(int i=(j);i<=(k);i++)
#define per(i,j,k) for(int i=(j);i>=(k);i--)
#define pt(a) cerr<<#a<<'='<<a<<' '
#define pts(a) cerr<<#a<<'='<<a<<'\n'
#define int long long
using namespace std;
const int MAXN=5e5;
struct SAM{int fa,len,ch[2],ed;}tr[MAXN<<1];
int tot=1,lst=1;
void ins(int c){
	int p=lst,np=++tot;lst=np;
	tr[np].len=tr[p].len+1;
	while(p&&!tr[p].ch[c]) tr[p].ch[c]=np,p=tr[p].fa;
	if(!p) tr[np].fa=1;
	else{
		int v=tr[p].ch[c];
		if(tr[v].len==tr[p].len+1) tr[np].fa=v;
		else{
			int nv=++tot;tr[nv]=tr[v];
			tr[nv].len=tr[p].len+1;
			while(p&&tr[p].ch[c]==v) tr[p].ch[c]=nv,p=tr[p].fa;
			tr[v].fa=tr[np].fa=nv;
		}
	}
}
int suf[MAXN],st[MAXN][20];
int qmin(int l,int r){
	int k=__lg(r-l+1);
	return min(st[l][k],st[r-(1<<k)+1][k]);
}
vector<int> pos[MAXN<<1],e[MAXN<<1];
int ans,n;
void dfs(int x){
	for(int s:e[x])
		dfs(s),tr[x].ed=tr[s].ed;
	if(x==1) return;
	int l=tr[tr[x].fa].len+1,r=tr[x].len,mxl=l-1;
	while(l<=r){
		int mid=(l+r)>>1;
		if(qmin(tr[x].ed-mid+1,tr[x].ed)>=suf[tr[x].ed+1])
			l=mid+1,mxl=mid;
		else r=mid-1;
	}
	if(mxl==tr[tr[x].fa].len) return;
	int v=suf[tr[x].ed+1];
	ans+=upper_bound(all(pos[v+n]),tr[x].ed-tr[tr[x].fa].len)-
		 lower_bound(all(pos[v+n]),tr[x].ed-mxl+1);
}
signed main()
{
	ios::sync_with_stdio(0);
	cin.tie(0);cout.tie(0);
	string s;cin>>n>>s;
	s=' '+s;
	rep(i,1,n) ins(s[i]==')');
	per(i,n,1) suf[i]=suf[i+1]+(s[i]=='('?-1:1);
	rep(i,1,n) pos[suf[i]+n].pb(i),st[i][0]=suf[i];
	rep(j,1,19) for(int i=1;i+(1<<(j-1))<=n;i++)
		st[i][j]=min(st[i][j-1],st[i+(1<<(j-1))][j-1]);
	int p=1;
	rep(i,1,n){
		p=tr[p].ch[s[i]==')'];
		tr[p].ed=i;
	}
	rep(i,2,tot) e[tr[i].fa].pb(i);
	dfs(1);
	cout<<ans<<'\n';
	return 0;
}
posted @ 2022-08-08 22:12  ZCETHAN  阅读(20)  评论(0编辑  收藏  举报