POJ 3376 Finding Palindromes(扩展kmp+trie)
题目链接:http://poj.org/problem?id=3376
题意:给你n个字符串m1、m2、m3...mn 求S = mimj(1=<i,j<=n)是回文串的数量
思路:我们考虑第i个字符串和第j个字符串能构成组合回文串要满足的条件:
1、i的长度小于j,那么i一定是j的反串的前缀,且j的反串剩下的后缀是回文串
2、i的长度等于j,那么i等于j的反串
3、i的长度大于j,那么j的反串一定是i的前缀,且i串剩下的后缀是回文串
我们可以将这n个字符串插入trie,每个节点要维护两个值:value1. 到当前节点的字符串个数;value2. 当前节点后面的回文子串个数
我们用每个字符串的反串去trie上查找,要构成回文串有以下情况:
1、 此反串是其他串的前缀,那么组合回文串的数量就要加上value2
2、此反串的前缀是某些字符串,且反串剩下的后缀是回文串,那么组合回文串的数量要加上value1
3、2的特例:此反串的前缀是某些字符串,且反串剩下的后缀为空,同样要加上value1,这种情况可以和2一起处理
关键:
1、判断字符串的哪些后缀是回文串(用于更新value2),以及对应反串的哪些后缀是回文串(当面临第二种情况时,可直接判断后缀否为回文串)
2、如何更新value1和value2(借助1的结果)
1 #include <cstdio> 2 #include <cstring> 3 #include <algorithm> 4 using namespace std; 5 typedef long long LL; 6 const int MAXN = 2000005; 7 const int KIND = 26; 8 9 struct TrieNode 10 { 11 int num; // 到当前节点的字符串个数 12 int cnt; // 当前节点后面回文子串个数 13 TrieNode* nxt[26]; 14 }; 15 16 TrieNode node[MAXN]; // 避免动态申请空间的时间消耗 17 TrieNode* root; // trie树的根节点 18 int bg[MAXN]; // bg[i]第i+1个字符串开始的位置 19 int ed[MAXN]; // ed[i]第i+1个字符串结束的位置 20 bool flag[2][MAXN]; // flag[0][i]为true表示原串后面为回文串 flag[1][i]表示反串 21 char S[MAXN]; // 存放原串 22 char T[MAXN]; // 存放反串 23 int nxt[MAXN]; // 存放next数组 24 int extend[MAXN]; // 用于判断是否为回文子串 25 LL ans; // 保存结果 26 int tot; // node数组的下标 27 28 void GetNext(char* T, int lhs, int rhs) 29 { 30 int j = 0; 31 while (lhs + j + 1 <= rhs && T[lhs + j] == T[lhs + j + 1]) ++j; 32 nxt[lhs + 1] = j; 33 int k = lhs + 1; 34 for (int i = lhs + 2; i <= rhs; ++i) 35 { 36 int p = nxt[k] + k - 1; 37 int L = nxt[lhs + i - k]; 38 if (L + i < p + 1) nxt[i] = L; 39 else 40 { 41 j = max(0, p - i + 1); 42 while (i + j <= rhs && T[lhs + j] == T[i + j]) ++j; 43 nxt[i] = j; 44 k = i; 45 } 46 } 47 } 48 49 void ExtendKMP(char* S, char* T, int lhs, int rhs, bool sign) 50 { 51 GetNext(T, lhs, rhs); 52 int j = 0; 53 while (j + lhs <= rhs && S[j + lhs] == T[j + lhs]) ++j; 54 extend[lhs] = j; 55 int k = lhs; 56 for (int i = lhs + 1; i <= rhs; ++i) 57 { 58 int p = extend[k] + k - 1; 59 int L = nxt[lhs + i - k]; 60 if (L + i < p + 1) extend[i] = L; 61 else 62 { 63 j = max(0, p - i + 1); 64 while (i + j <= rhs && S[i + j] == T[lhs + j]) ++j; 65 extend[i] = j; 66 k = i; 67 } 68 } 69 for (int i = lhs; i <= rhs; ++i) 70 { 71 if (extend[i] == rhs - i + 1) 72 flag[sign][i] = true; 73 } 74 } 75 76 void Insert(char S[], int lhs, int rhs) 77 { 78 TrieNode* temp = root; 79 for (int i = lhs; i <= rhs; ++i) 80 { 81 int ch = S[i] - 'a'; 82 temp->cnt += flag[0][i]; // 更新当前节点后面回文子串的数目 83 if (temp->nxt[ch] == NULL) temp->nxt[ch] = &node[tot++]; 84 temp = temp->nxt[ch]; 85 } 86 ++temp->num; // 更新到当前节点的字符串数目 87 } 88 89 void Search(char S[], int lhs, int rhs) 90 { 91 TrieNode* temp = root; 92 for (int i = lhs; i <= rhs; ++i) 93 { 94 int ch = S[i] - 'a'; 95 temp = temp->nxt[ch]; 96 if (temp == NULL) break; 97 if ((i < rhs && flag[1][i + 1]) || i == rhs) 98 ans += temp->num; 99 } 100 if (temp) ans += temp->cnt; 101 } 102 103 int main() 104 { 105 int n; 106 while (scanf("%d", &n) != EOF) 107 { 108 // 初始化 109 tot = 0; 110 ans = 0; 111 memset(node, 0, sizeof(node)); 112 memset(flag, 0, sizeof(flag)); 113 root = &node[tot++]; 114 115 int l = 0; 116 int L = 0; 117 for (int i = 0; i < n; ++i) 118 { 119 // 输入一组数据 120 scanf("%d", &l); 121 scanf("%s", S + L); 122 123 // 生成反串 124 for (int j = 0; j < l; ++j) 125 T[L + j] = S[L + l - 1 - j]; 126 127 bg[i] = L; 128 ed[i] = L + l - 1; 129 130 131 ExtendKMP(S, T , bg[i], ed[i], 0); 132 ExtendKMP(T, S , bg[i], ed[i], 1); 133 Insert(S, bg[i], ed[i]); 134 135 L += l; 136 } 137 138 for (int i = 0; i < n; ++i) 139 Search(T, bg[i], ed[i]); 140 141 printf("%lld\n", ans); 142 } 143 return 0; 144 }