对于两个字符串,我们想知道他们有多少不同的公共子串,不妨可以考虑成对于戴尔个串的每个不同的后缀,有多少个相同后缀子串。
于是可以考虑成为,我们对于第一个串先建立一个后缀树(link树),然后对于第二个串,我们在第一个后缀树上跑,来求答案,但是第二个串要怎么跑呢?我们不妨将第二个串也插入到后缀树上去。
这样以来,两个串都在后缀树上了,给第一个串设立成为“存在点”,将第二个串设置成为“查询点”。然后我们只需要在后缀树上dfs跑一个贡献就可以了,由于后缀树实际上是路径压缩的,所以不同的点的个数,实际上还要乘以它和它的父亲节点(link节点)的长度差值。
1 #include <iostream> 2 #include <cstdio> 3 #include <cmath> 4 #include <string> 5 #include <cstring> 6 #include <algorithm> 7 #include <limits> 8 #include <vector> 9 #include <stack> 10 #include <queue> 11 #include <set> 12 #include <map> 13 #include <bitset> 14 #include <unordered_map> 15 #include <unordered_set> 16 #define lowbit(x) ( x&(-x) ) 17 #define pi 3.141592653589793 18 #define e 2.718281828459045 19 #define INF 0x3f3f3f3f 20 #define HalF (l + r)>>1 21 #define lsn rt<<1 22 #define rsn rt<<1|1 23 #define Lson lsn, l, mid 24 #define Rson rsn, mid+1, r 25 #define QL Lson, ql, qr 26 #define QR Rson, ql, qr 27 #define myself rt, l, r 28 #define pii pair<int, int> 29 #define MP(a, b) make_pair(a, b) 30 using namespace std; 31 typedef unsigned long long ull; 32 typedef unsigned int uit; 33 typedef long long ll; 34 const int maxN = 4e5 + 7; 35 const int maxP = maxN << 1; 36 ll ans; 37 struct SAM 38 { 39 struct state 40 { 41 int len, link, next[26]; 42 } st[maxP]; 43 int siz = 1, last; 44 int dp[maxP] = {0}, query[maxP] = {0}; 45 void init() 46 { 47 siz = last = 1; 48 st[1].len = 0; 49 st[1].link = 0; 50 memset(st[1].next, 0, sizeof(st[1].next)); 51 siz++; 52 } 53 int extend(int c, int val, int question) 54 { 55 if(st[last].next[c] && st[last].len + 1 == st[st[last].next[c]].len) 56 { 57 last = st[last].next[c]; 58 dp[last] += val; 59 query[last] += question; 60 return last; 61 } 62 int cur = siz++; 63 st[cur].len = st[last].len + 1; 64 dp[cur] += val; 65 query[cur] += question; 66 int p = last; 67 while (p && !st[p].next[c]) 68 { 69 st[p].next[c] = cur; 70 p = st[p].link; 71 } 72 if (p == 0) 73 { 74 st[cur].link = 1; 75 } 76 else 77 { 78 int q = st[p].next[c]; 79 if (st[p].len + 1 == st[q].len) 80 { 81 st[cur].link = q; 82 } 83 else 84 { 85 int clone; 86 if(p == last) 87 { 88 clone = cur; 89 } 90 else 91 { 92 clone = siz++; 93 st[cur].link = clone; 94 } 95 st[clone] = st[q]; 96 st[q].link = clone; 97 st[clone].len = st[p].len + 1; 98 while (p != 0 && st[p].next[c] == q) 99 { 100 st[p].next[c] = clone; 101 p = st[p].link; 102 } 103 } 104 } 105 return last = cur; 106 } 107 vector<int> to[maxP]; 108 void bfs() 109 { 110 for(int i=2; i<siz; i++) to[st[i].link].push_back(i); 111 } 112 void dfs(int u) 113 { 114 for(int v : to[u]) 115 { 116 dfs(v); 117 dp[u] += dp[v]; 118 query[u] += query[v]; 119 } 120 if(u ^ 1) ans += 1LL * query[u] * dp[u] * (st[u].len - st[st[u].link].len); 121 } 122 } sam; 123 int main() 124 { 125 int n1, n2; 126 char s[maxN]; 127 scanf("%s", s); 128 n1 = (int)strlen(s); 129 sam.init(); 130 for(int i=0; i<n1; i++) 131 { 132 sam.extend(s[i] - 'a', 1, 0); 133 } 134 scanf("%s", s); 135 n2 = (int)strlen(s); 136 sam.last = 1; 137 for(int i=0; i<n2; i++) 138 { 139 sam.extend(s[i] - 'a', 0, 1); 140 } 141 sam.bfs(); 142 ans = 0; 143 sam.dfs(1); 144 printf("%lld\n", ans); 145 return 0; 146 } 147 /* 148 ababaa 149 aba 150 ans:16 151 */