CF235C Cyclical Quest

题意

给定一个主串\(S\)\(n\)个询问串,求每个询问串的所有循环同构在主串中出现的次数总和。
相同的循环同构只算一次

题解

sam的其中一个作用就是可以统计某个子串的出现次数,这个很好搞。

我们把询问字符串复制两遍, 设询问串长度为\(m\), 则复制串中所有长度为\(m\)的的子串都是循环同构

对于复制串中的每个位置,我们维护向左的最远匹配,当然这个匹配不用太远,最远到当前这个endpos自身包含长度为\(m\)的串就行

所以匹配过程就是: 初始在1状态,匹配长度为0, 每次向右加入一个字符时,如果上一个状态能转移就转移,使得当前匹配长度加一,否则就不断的跳link,将匹配长度置为\(len[link]\), (这样做没有问题, 对于一个匹配长度和它对应的\(endpos_i\),匹配程度len一定有 \(len_i >= len >len_{i-1}\), 然后对于一个有转移的link,转移后加一, 所以匹配长度为\(link+1\)也没什么问题。

对于每次转移完成,如果当前匹配长度大于等于\(m\),但当前endpos不包含\(m\),我们就不断跳link,并且把匹配长度置为link.len,这也没什么问题,每次转移后,当前endpos应该也是包含匹配长度的?

至于重复的同构,因为所有的同构长度相同,本质相同的同构一定会走到相同的endpos,对于每一个endpos只统计一次答案即可。

实现

#include <iostream>
#include <cstdio>
#include <vector>
#include <string>
#include <set>
#define ll long long
using namespace std;

int read(){
	int num=0, flag=1; char c=getchar();
	while(!isdigit(c) && c!='-') c=getchar();
	if(c=='-') flag=-1, c=getchar();
	while(isdigit(c)) num=num*10+c-'0', c=getchar();
	return num*flag; 
}

int readc(){
	char c=getchar();
	while(c<'a' || c>'z') c=getchar();
	return c-'a'; 
}

const int N = 1000500;
int n, m;
char s[N];


namespace sam{
	struct{
		int len, link, siz=0;
		int ch[26];
	}st[N<<1];
	vector<int> ptr[N<<1];
	int las, sz;
	
	void init(){
		las=1, sz=1;
		st[1].len=0, st[1].link=0; 
	}
	
	void extend(int c){
		int cur=++sz, p=las;
		st[cur].len=st[las].len+1, st[cur].siz=1;
		while(p && !st[p].ch[c]) 
			st[p].ch[c]=cur, p=st[p].link;
		
		if(p){
			int nex = st[p].ch[c];
			if(st[p].len+1 == st[nex].len){
				st[cur].link = nex;
			}else{
				int clone = ++sz;
				st[clone].len=st[p].len+1, st[clone].link = st[nex].link;
				for(int i=0; i<26; i++) st[clone].ch[i] = st[nex].ch[i];
				st[cur].link=clone, st[nex].link=clone;
				
				while(st[p].ch[c]==nex) st[p].ch[c]=clone, p=st[p].link;
			}
		}else{
			st[cur].link=1;
		}
		
		las = cur;
	}
	
	void dfsPtr(int x){
		for(int i=0; i<ptr[x].size(); i++) {
			dfsPtr(ptr[x][i]);
			st[x].siz += st[ptr[x][i]].siz;
		}
	}
	
	void buildPtr(){
        // for(int i=1; i<=sz; i++){
        //     for(int j=0; j<26; j++){
        //         if(st[i].ch[j])
        //             printf("%d %d %c\n", i, st[i].ch[j], j+'a');
        //     }
        // }

		for(int i=1; i<=sz; i++) {
			ptr[st[i].link].push_back(i);
			// printf("%d %d %d\n", st[i].link, i, st[i].len); 
		}
		dfsPtr(1);
	}
	

}

string str;
int t[N];

void solve(){
    string str1; cin>>str1;
    m = str1.size();
    for(int i=0; i<2*m; i++){
        t[i] = str1[i%m]-'a';
    }
    set<int> v;

    int x = 1; int y = 0;
    for(int i=0; i<2*m; i++){
        while(true){
            if(sam::st[x].ch[t[i]]){
                x = sam::st[x].ch[t[i]];
                y++;
                break;
            }

            if(x == 1) break;
            x = sam::st[x].link;
            y = sam::st[x].len;
        }

        while(sam::st[sam::st[x].link].len >= m) x = sam::st[x].link, y = sam::st[x].len;

        if(sam::st[x].len >= m && y>=m) {
            if(v.find(x) == v.end()){
                v.insert(x);
            }
        }
    }

    ll ans = 0;
    for(auto i : v){
        ans += sam::st[i].siz;
    }
    printf("%lld\n", ans);

}

int main(){
	cin >> str;
	for(int i=0; i<str.size(); i++){
		s[i+1] = str[i]-'a';
	} n = str.size();
//	for(int i=1; i<=n; i++) s[i]=readc();
	sam::init();
	for(int i=1; i<=n; i++) sam::extend(s[i]); 
    sam::buildPtr();
    
    int T = read();
    while(T--){
        solve();
    }
	
	return 0;
}

posted @ 2024-08-03 12:24  ltdJcoder  阅读(3)  评论(0编辑  收藏  举报