CF235C Cyclical Quest
题意
给定一个主串\(S\)和\(n\)个询问串,求每个询问串的所有循环同构在主串中出现的次数总和。
相同的循环同构只算一次
题解
sam的其中一个作用就是可以统计某个子串的出现次数,这个很好搞。
我们把询问字符串复制两遍, 设询问串长度为\(m\), 则复制串中所有长度为\(m\)的的子串都是循环同构
对于复制串中的每个位置,我们维护向左的最远匹配,当然这个匹配不用太远,最远到当前这个endpos自身包含长度为\(m\)的串就行
所以匹配过程就是: 初始在1状态,匹配长度为0, 每次向右加入一个字符时,如果上一个状态能转移就转移,使得当前匹配长度加一,否则就不断的跳link,将匹配长度置为\(len[link]\), (这样做没有问题, 对于一个匹配长度和它对应的\(endpos_i\),匹配程度len一定有 \(len_i >= len >len_{i-1}\), 然后对于一个有转移的link,转移后加一, 所以匹配长度为\(link+1\)也没什么问题。
对于每次转移完成,如果当前匹配长度大于等于\(m\),但当前endpos不包含\(m\),我们就不断跳link,并且把匹配长度置为link.len,这也没什么问题,每次转移后,当前endpos应该也是包含匹配长度的?
至于重复的同构,因为所有的同构长度相同,本质相同的同构一定会走到相同的endpos,对于每一个endpos只统计一次答案即可。
实现
#include <iostream>
#include <cstdio>
#include <vector>
#include <string>
#include <set>
#define ll long long
using namespace std;
int read(){
int num=0, flag=1; char c=getchar();
while(!isdigit(c) && c!='-') c=getchar();
if(c=='-') flag=-1, c=getchar();
while(isdigit(c)) num=num*10+c-'0', c=getchar();
return num*flag;
}
int readc(){
char c=getchar();
while(c<'a' || c>'z') c=getchar();
return c-'a';
}
const int N = 1000500;
int n, m;
char s[N];
namespace sam{
struct{
int len, link, siz=0;
int ch[26];
}st[N<<1];
vector<int> ptr[N<<1];
int las, sz;
void init(){
las=1, sz=1;
st[1].len=0, st[1].link=0;
}
void extend(int c){
int cur=++sz, p=las;
st[cur].len=st[las].len+1, st[cur].siz=1;
while(p && !st[p].ch[c])
st[p].ch[c]=cur, p=st[p].link;
if(p){
int nex = st[p].ch[c];
if(st[p].len+1 == st[nex].len){
st[cur].link = nex;
}else{
int clone = ++sz;
st[clone].len=st[p].len+1, st[clone].link = st[nex].link;
for(int i=0; i<26; i++) st[clone].ch[i] = st[nex].ch[i];
st[cur].link=clone, st[nex].link=clone;
while(st[p].ch[c]==nex) st[p].ch[c]=clone, p=st[p].link;
}
}else{
st[cur].link=1;
}
las = cur;
}
void dfsPtr(int x){
for(int i=0; i<ptr[x].size(); i++) {
dfsPtr(ptr[x][i]);
st[x].siz += st[ptr[x][i]].siz;
}
}
void buildPtr(){
// for(int i=1; i<=sz; i++){
// for(int j=0; j<26; j++){
// if(st[i].ch[j])
// printf("%d %d %c\n", i, st[i].ch[j], j+'a');
// }
// }
for(int i=1; i<=sz; i++) {
ptr[st[i].link].push_back(i);
// printf("%d %d %d\n", st[i].link, i, st[i].len);
}
dfsPtr(1);
}
}
string str;
int t[N];
void solve(){
string str1; cin>>str1;
m = str1.size();
for(int i=0; i<2*m; i++){
t[i] = str1[i%m]-'a';
}
set<int> v;
int x = 1; int y = 0;
for(int i=0; i<2*m; i++){
while(true){
if(sam::st[x].ch[t[i]]){
x = sam::st[x].ch[t[i]];
y++;
break;
}
if(x == 1) break;
x = sam::st[x].link;
y = sam::st[x].len;
}
while(sam::st[sam::st[x].link].len >= m) x = sam::st[x].link, y = sam::st[x].len;
if(sam::st[x].len >= m && y>=m) {
if(v.find(x) == v.end()){
v.insert(x);
}
}
}
ll ans = 0;
for(auto i : v){
ans += sam::st[i].siz;
}
printf("%lld\n", ans);
}
int main(){
cin >> str;
for(int i=0; i<str.size(); i++){
s[i+1] = str[i]-'a';
} n = str.size();
// for(int i=1; i<=n; i++) s[i]=readc();
sam::init();
for(int i=1; i<=n; i++) sam::extend(s[i]);
sam::buildPtr();
int T = read();
while(T--){
solve();
}
return 0;
}