BZOJ4566 Haoi2016 找相同字符【广义后缀自动机】
Description
给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两
个子串中有一个位置不同。
Input
两行,两个字符串s1,s2,长度分别为n1,n2。1 <=n1, n2<= 200000,字符串中只有小写字母
Output
输出一个整数表示答案
Sample Input
aabb
bbaa
Sample Output
10
可以说是模板了
因为只有两个串,所以完全没有必要建trie
直接分别添加,在每个串开始添加之前让last=root就可以了
然后在extend上面加上已经有转移的特判
注意要后缀自动机建完之后再添加right标记
不然会出问题
//Author: dream_maker
#include<bits/stdc++.h>
using namespace std;
//----------------------------------------------
//typename
typedef long long ll;
//convenient for
#define for_up(a, b, c) for (int a = b; a <= c; ++a)
#define for_down(a, b, c) for (int a = b; a >= c; --a)
#define for_vector(a, b) for (int a = 0; a < (signed)b.size(); ++a)
//inf of different typename
const int INF_of_int = 1e9;
const ll INF_of_ll = 1e18;
//fast read and write
template <typename T>
void Read(T &x){
bool w = 1;x = 0;
char c = getchar();
while(!isdigit(c) && c != '-')c = getchar();
if(c == '-')w = 0, c = getchar();
while(isdigit(c)) {
x = (x<<1) + (x<<3) + c -'0';
c = getchar();
}
if(!w)x=-x;
}
template <typename T>
void Write(T x){
if(x < 0) {
putchar('-');
x=-x;
}
if(x > 9)Write(x / 10);
putchar(x % 10 + '0');
}
//----------------------------------------------
const int CHARSET_SIZE = 26;
const int N = 4e5 + 10;
struct Sam {
Sam *ch[CHARSET_SIZE], *prt;
int maxl, right[2];
Sam (int maxl = 0):ch(),prt(NULL),maxl(maxl){
right[0] = right[1] = 0;
}
}pool[N<<1],*cur = pool, *root = new (cur++)Sam, *last = root;
int buc[N];
vector<Sam*> topo;
ll ans = 0;
void extend(int c){
if (last->ch[c] && last->maxl + 1 == last->ch[c]->maxl) {
last = last->ch[c];
return;
}
Sam *u = new (cur++)Sam(last->maxl + 1), *v = last;
for (; v && !v->ch[c]; v = v->prt) v->ch[c] = u;
if (!v) {
u->prt = root;
} else if (v->maxl + 1 == v->ch[c]->maxl) {
u->prt = v->ch[c];
} else {
Sam *n = new (cur++)Sam(v->maxl + 1), *o = v->ch[c];
copy(o->ch, o->ch + CHARSET_SIZE, n->ch);
n->prt = o->prt;
o->prt = u->prt = n;
for (; v && v->ch[c] == o;v = v->prt) v->ch[c] = n;
}
last = u;
}
void toposort(){
int maxv = 0;
for (Sam *p = pool; p != cur; ++p){
maxv = max(maxv, p->maxl);
buc[p->maxl]++;
}
for_up(i, 1, maxv) buc[i] += buc[i-1];
topo.resize(cur-pool);
for (Sam *p = pool; p != cur; ++p)topo[--buc[p->maxl]] = p;
for_down(i, topo.size()-1, 1) {
Sam *p = topo[i];
p->prt->right[0] += p->right[0];
p->prt->right[1] += p->right[1];
ans += 1ll * (p->maxl - p->prt->maxl) * p->right[0] * p->right[1];
}
}
char c[2][N];
int main() {
scanf("%s",c[0] + 1);
int len0 = strlen(c[0] + 1);
for_up(i, 1, len0) extend(c[0][i]-'a');
last = root;
scanf("%s",c[1] + 1);
int len1 = strlen(c[1] + 1);
for_up(i, 1, len1) extend(c[1][i]-'a');
Sam *now = root;
for_up(i, 1, len0) {
now = now->ch[c[0][i]-'a'];
now->right[0] = 1;
}
now = root;
for_up(i, 1, len1) {
now = now->ch[c[1][i]-'a'];
now->right[1] = 1;
}
toposort();
Write(ans);
return 0;
}