CF653F Paper task
题意简述
求一个长度为 \(n\) 的括号串的不同合法括号串的个数。\(n\le 5\times 10^5\)。
Solution
真是 SAM 白学了呢……
看到题目中要求本质不同的子串,那肯定是除了 SAM 没有人能够胜任了。然后考虑到括号串判断合法直接做比较麻烦,所以根据后缀自动机的性质,我们希望能够通过后缀和来反映这个子串是否合法。比如说,我们考虑令 (
为 \(-1\),)
为 \(1\),则只要后缀和都非负,并且整个的和是 \(0\),则是合法的串。
那把 SAM 建立起来,考虑每一个节点。一个节点代表了一个 endpos 等价类。那我们在这些 endpos 中任意取一个,则这个节点所代表的子串就是以 endpos 为结尾,开头在这个节点所代表的 len 的区间内的所有子串。
然后我们现在已经知道需要具体计算哪些子串是合法的,那么就要用到上面提到的转化。我们现假设待判断的串以 \(ed\) 位置结尾,开头在 \([l,r]\) 区间内,\(suf\) 是原串的后缀和,则此区间中合法的子串,应当满足:
\[\forall st\in[p,ed],suf[st]\ge suf[ed+1]\\
suf[st]-suf[ed+1]=0
\]
也就是问题转化成,求区间中有多少个数 \(st\),满足 \([st,ed]\) 中的数都大于某一个数,并且 \(st\) 位置的数等于那个数。直接用 ST 表判断肯定是 TT,所以考虑一个小优化。
你发现,这东西很显然越往前越容易不合法,所以直接二分这个界限就可以了。
所以相当于我们需要求区间中的最小值和区间中等于某个数的个数。直接用 st 表维护,并用 vector 存储每个元素出现的所有下标,然后在每个 vector 内二分出位于当前区间的那段,就可以知道最终答案了。
Code
// Problem: CF653F Paper task
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/CF653F
// Memory Limit: 2600 MB
// Time Limit: 500000 ms
#include<bits/stdc++.h>
#define ll long long
#define inf (1<<30)
#define INF (1ll<<60)
#define pb emplace_back
#define pii pair<int,int>
#define mkp make_pair
#define fi first
#define se second
#define all(a) a.begin(),a.end()
#define siz(a) (int)a.size()
#define clr(a) memset(a,0,sizeof(a))
#define rep(i,j,k) for(int i=(j);i<=(k);i++)
#define per(i,j,k) for(int i=(j);i>=(k);i--)
#define pt(a) cerr<<#a<<'='<<a<<' '
#define pts(a) cerr<<#a<<'='<<a<<'\n'
#define int long long
using namespace std;
const int MAXN=5e5;
struct SAM{int fa,len,ch[2],ed;}tr[MAXN<<1];
int tot=1,lst=1;
void ins(int c){
int p=lst,np=++tot;lst=np;
tr[np].len=tr[p].len+1;
while(p&&!tr[p].ch[c]) tr[p].ch[c]=np,p=tr[p].fa;
if(!p) tr[np].fa=1;
else{
int v=tr[p].ch[c];
if(tr[v].len==tr[p].len+1) tr[np].fa=v;
else{
int nv=++tot;tr[nv]=tr[v];
tr[nv].len=tr[p].len+1;
while(p&&tr[p].ch[c]==v) tr[p].ch[c]=nv,p=tr[p].fa;
tr[v].fa=tr[np].fa=nv;
}
}
}
int suf[MAXN],st[MAXN][20];
int qmin(int l,int r){
int k=__lg(r-l+1);
return min(st[l][k],st[r-(1<<k)+1][k]);
}
vector<int> pos[MAXN<<1],e[MAXN<<1];
int ans,n;
void dfs(int x){
for(int s:e[x])
dfs(s),tr[x].ed=tr[s].ed;
if(x==1) return;
int l=tr[tr[x].fa].len+1,r=tr[x].len,mxl=l-1;
while(l<=r){
int mid=(l+r)>>1;
if(qmin(tr[x].ed-mid+1,tr[x].ed)>=suf[tr[x].ed+1])
l=mid+1,mxl=mid;
else r=mid-1;
}
if(mxl==tr[tr[x].fa].len) return;
int v=suf[tr[x].ed+1];
ans+=upper_bound(all(pos[v+n]),tr[x].ed-tr[tr[x].fa].len)-
lower_bound(all(pos[v+n]),tr[x].ed-mxl+1);
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
string s;cin>>n>>s;
s=' '+s;
rep(i,1,n) ins(s[i]==')');
per(i,n,1) suf[i]=suf[i+1]+(s[i]=='('?-1:1);
rep(i,1,n) pos[suf[i]+n].pb(i),st[i][0]=suf[i];
rep(j,1,19) for(int i=1;i+(1<<(j-1))<=n;i++)
st[i][j]=min(st[i][j-1],st[i+(1<<(j-1))][j-1]);
int p=1;
rep(i,1,n){
p=tr[p].ch[s[i]==')'];
tr[p].ed=i;
}
rep(i,2,tot) e[tr[i].fa].pb(i);
dfs(1);
cout<<ans<<'\n';
return 0;
}