[CF235C]Cyclical Quest(SAM)
题面
https://codeforces.com/contest/235/problem/C
题解
前置知识
给出字符串s,要求多次询问给出的字符串t的所有不同的循环同构在s中出现的次数的总和。
首先处理一下这个“不同的循环同构”。相当于首先要把t的最小循环节求出来。有一个结论:\(t_0\)是t的循环节\(\iff t_0+t=t+t_0\)。(+表示字符串拼接)所以从短到长枚举t的前缀,再用字符串哈希判断就可以了。
设t的最小循环节长度是T。设字符串\(t'=t+t\),那么t的所有不同的循环同构就是\(t'[1..|t|],t'[2..|t|+1],…,t'[T..|t|+T-1]\)。
然后需要求它们在S中出现的次数。为此,我们可以先建出s的后缀自动机,预处理出s各个子串的出现个数,然后将\(t'\)放到s的后缀自动机上匹配。如果当前匹配到\(t'[r]\)这一位(\(r{\geq}|t|\)),到达SAM节点u,那么查看这时匹配的长度len。如果len<\(|t|\),那么说明\(t'[r-T+1..r]\)在s中没有出现;否则,不断执行u=fail[u],将len赋值为u的长度,直到len即将<\(|t|\)。所得的u就是\(t'[r-T+1 .. r]\)对应的节点,将u在S中出现的次数计入答案即可。
更具体细节可见代码中SAM.query()部分
时间复杂度方面,考虑len这个值,它的总增加量最多是\(O(|t|)\)(每匹配一位最多+1);同时在执行u=fail[u]的过程中,len至少-1,所以可以说明总共最多执行了\(O(|t|)\)次u=fail[u],也就是每次询问时间复杂度\(O(|t|)\)得到了保证。
总时间复杂度\(O(|s|+\sum{|t|})\)。
代码
#include<bits/stdc++.h>
using namespace std;
#define rg register
#define In inline
#define ll long long
const int N = 1e6;
const ll mod1 = 998244353,mod2 = 1e9 + 7;
const ll base = 29;
namespace ModCalc{
void Inc(ll &x,ll y,ll mod){
x += y;if(x >= mod)x -= mod;
}
void Dec(ll &x,ll y,ll mod){
x -= y;if(x < 0)x += mod;
}
ll Add(ll x,ll y,ll mod){
Inc(x,y,mod);return x;
}
ll Sub(ll x,ll y,ll mod){
Dec(x,y,mod);return x;
}
}
using namespace ModCalc;
In void write(ll x){
if(x < 0)putchar('-'),x = -x;
if(x > 9)write(x / 10);
putchar('0' + x % 10);
}
char s[2*N+5];
struct SAM{
int cnt,last,nx[2*N+5][26],fail[2*N+5],len[2*N+5];
ll num[2*N+5];
void clear(){
fail[0] = -1;
}
void extend(char c){
int id = c - 'a';
int cur = ++cnt,p;
num[cur] = 1;
for(p = last;p != -1 && !nx[p][id];p = fail[p])nx[p][id] = cur;
if(p == -1)fail[cur] = 0;
else{
int q = nx[p][id];
if(len[q] == len[p] + 1)fail[cur] = q;
else{
int clone = ++cnt;
len[clone] = len[p] + 1;
fail[clone] = fail[q];
memcpy(nx[clone],nx[q],sizeof(nx[clone]));
fail[q] = fail[cur] = clone;
for(;p != -1 && nx[p][id] == q;p = fail[p])nx[p][id] = clone;
}
}
last = cur;
}
vector<int>link[2*N+5];
void dfs(int u){
for(rg int i = 0;i < link[u].size();i++){
int v = link[u][i];
dfs(v);
num[u] += num[v];
}
}
void prepro(){
for(rg int i = 1;i <= cnt;i++)link[fail[i]].push_back(i);
dfs(0);
}
ll query(char s[],int L,int n){ //求长度为n的字符串s的所有长度为L的子串在主串中出现的次数之和
int u,i;
ll ans = 0,curlen = 0;
for(i = 1,u = 0;i <= n;i++){
while(u && !nx[u][s[i]-'a'])u = fail[u],curlen = len[u];
if(nx[u][s[i]-'a'])u = nx[u][s[i]-'a'],curlen++;
if(i >= L){
if(curlen >= L){
while(len[fail[u]] >= L)u = fail[u];
ans += num[u];
}
}
}
return ans;
}
}S;
ll pow1[N+5],pow2[N+5];
void prepro(){
pow1[0] = pow2[0] = 1;
for(rg int i = 1;i <= N;i++){
pow1[i] = pow1[i-1] * base % mod1;
pow2[i] = pow2[i-1] * base % mod2;
}
}
struct str{
ll h1,h2,len;
str(){h1 = h2 = len = 0;}
str(ll _h1,ll _h2,ll _len){h1 = _h1,h2 = _h2,len = _len;}
str(char c){h1 = h2 = c - 'a',len = 1;}
In friend str operator + (str a,str b){
return str(Add(a.h1*pow1[b.len]%mod1,b.h1,mod1),Add(a.h2*pow2[b.len]%mod2,b.h2,mod2),a.len + b.len);
}
In friend bool operator == (str a,str b){
return a.h1 == b.h1 && a.h2 == b.h2 && a.len == b.len;
}
}h[N+5];
int calcT(char s[]){ //计算字符串s的最短循环节
h[0] = str();
int n = strlen(s + 1);
for(rg int i = 1;i <= n;i++)h[i] = h[i-1] + str(s[i]);
for(rg int i = 1;i <= n;i++)if((h[i]+h[n]) == (h[n]+h[i]))return i;
}
int main(){
scanf("%s",s + 1);
int n = strlen(s + 1);
S.clear();
for(rg int i = 1;i <= n;i++)S.extend(s[i]);
S.prepro();
int m;
scanf("%d",&m);
prepro();
while(m--){
scanf("%s",s + 1);
int n = strlen(s + 1);
bool b = 1;
int T = calcT(s);
for(rg int i = 1;i < T;i++)s[n+i] = s[i];
write(S.query(s,n,n+T-1));
putchar('\n');
}
}