Codeforces 235C Cyclical Quest (后缀自动机)

思路:一眼看过去,好像处理出每个字串的最小表示的 \(hash\) 值就可以解决了, 但想了复杂度明显过不去,由于要统计某种子串个数,所以首先想到后缀自动机,然后分析,我们将每次查询的模式串翻倍(接在自身后面),模式串的原本长度为 \(n\) ,假设我们现在在后缀自动机上找到了区间 \((le, ri)\) 的子串,首先判断 \(ri - le + 1\) 是否等于 \(n\) , 若相等则加上该节点 \(ednpoints\) 集合大小,然后我们要查询的就是 $(le + 1, ri + 1) $ 的子串了,首先看子串 \((le + 1, ri)\) 是否属于该节点,若不属于,则沿着 \(link\) 链接向上跳,跳到包含子串 \((le + 1, ri)\) 的节点 \(p\) ,然后判断 \(st[p].next[s[ri + 1]]\) 是否存在,若存在,则 \(p\) 跳到 \(p = st[p].next[s[ri + 1]]\) ,否则 \(p\) 直接跳到 \(st[p].link\) , 并更新对应的 \(le\) 。具体看代码

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int maxn = 1e5 + 50;
struct state {
  int len, link;
  int next[26];
};

state st[maxn * 20];
int sz, last;

void sam_init() {
  st[0].len = 0;
  st[0].link = -1;
  sz = 1;
  last = 0;
}

LL num[maxn * 20];
void sam_extend(int c) {
  int cur = sz++;
  st[cur].len = st[last].len + 1;
  int p = last;
  while (p != -1 && !st[p].next[c]) {
    st[p].next[c] = cur;
    p = st[p].link;
  }
  if (p == -1) {
    st[cur].link = 0;
  } else {
    int q = st[p].next[c];
    if (st[p].len + 1 == st[q].len) {
      st[cur].link = q;
    } else {
      int clone = sz++;
      st[clone].len = st[p].len + 1;
      for(int i = 0; i < 26; i++) st[clone].next[i] = st[q].next[i];
      st[clone].link = st[q].link;
      while (p != -1 && st[p].next[c] == q) {
        st[p].next[c] = clone;
        p = st[p].link;
      }
      st[q].link = st[cur].link = clone;
    }
  }
  last = cur;
}

struct Edge
{
	int to, next;
} edge[maxn * 40];

int k, head[maxn * 20];
void add(int a, int b){
	edge[k].to = b;
	edge[k].next = head[a];
	head[a] = k++;
}

void dfs(int u, int pre){
	for(int i = head[u]; i != -1; i = edge[i].next){
		int to = edge[i].to;
		if(to == pre) continue;
		dfs(to, u);
		num[u] += num[to];
	}
}
string s, t;
int vis[maxn * 20];
int main(int argc, char const *argv[])
{
	cin >> t;
	int tlen = t.size();
	sam_init();
	for(int i = 0; i < tlen; i++){
		sam_extend(t[i] - 'a');
		num[last] = 1;
	}
	for(int i = 0; i < sz; i++) head[i] = -1;
	for(int i = 1; i < sz; i++){
		add(i, st[i].link);
		add(st[i].link, i);
	}

	dfs(0, -1);
	int q;
	scanf("%d", &q);
	int id = 0;
	while(q--){
		id++;
		cin >> s;
		int n = s.size();
		s += s;
		int p = 0;
		int le = 0, ri = 0;
		LL ans = 0;
		while(le < n && ri < 2 * n){
			if(st[p].next[s[ri] - 'a']){
				p = st[p].next[s[ri] - 'a'];
				if(ri - le + 1 == n){
					if(vis[p] != id){ // 记录一下该点的贡献已经加过,防止重复算贡献,比如第二个样例
						vis[p] = id;
						ans += num[p];
					}
					le++;
					while(st[st[p].link].len + 1 > n - 1 && p != 0){
						p = st[p].link;
					}
				}
				ri++;
			} else {
				if(p == 0) le++, ri = le; // 注意,若 p 是节点 0 ,则需要让 le++, 否则会死循环
				p = st[p].link; 
				le = ri - 1 - st[p].len + 1;
			}
		}
		printf("%lld\n", ans);
	}
	return 0;
}

posted @ 2020-08-07 21:09  从小学  阅读(69)  评论(0编辑  收藏  举报