poj3376 Finding Palindromes【exKMP】【Trie】
Time Limit: 10000MS    Memory Limit: 262144K
Total Submissions: 4710    Accepted: 879
Case Time Limit: 2000MS
Description
A word is called a palindrome if it reads the same from right to left as it does from left to right. For example, "dad", "eye" and "racecar" are all palindromes, but "odd", "see" and "orange" are not.
Given n strings, you can generate n × n pairs of them and concatenate the pairs into single words. The task is to count how many of the so generated words are palindromes.
Input
The first line of input file contains the number of strings n. The following n lines describe each string:
The (i+1)-th line contains the length $l_i$ of the i-th string, then a single space, then a string of $l_i$ lowercase English letters.
You can assume that the total length of all strings will not exceed 2,000,000. Two strings in different line may be the same.
Output
Print out only one integer, the number of palindromes.
Sample Input
3
1 a
2 ab
2 ba
Sample Output
5
Hint
The 5 palindromes generated from the sample are: aa, aba, aba, abba, baab.
Source
// POJ 3376 "Finding Palindromes": count the ordered pairs (i, j), i == j
// allowed, such that string i followed by string j is a palindrome.
//
// Approach (extended KMP + trie):
//   * All strings are stored back to back in `s`; `s_rev` holds, at the same
//     offsets, each string reversed. st[i]/ed[i] are the inclusive bounds of
//     string i in both arrays.
//   * flag[0][k] is true when the suffix of the original string that starts at
//     absolute position k is a palindrome; flag[1][k] is the same for the
//     reversed string. Each is produced by one extended-KMP pass.
//   * Every original string is inserted into a trie. cnt[p] counts strings that
//     end exactly at node p; val[p] counts strings that continue past p and
//     whose remaining suffix (after p's prefix) is a palindrome.
//   * For each string A, its reverse t is walked through the trie. A string B
//     makes B + A a palindrome iff either
//       - |B| <= |A|: B is a prefix of t and the rest of t is a palindrome
//         (counted with cnt[] and flag[1]), or
//       - |B| >  |A|: t is a prefix of B and the rest of B is a palindrome
//         (counted with val[] at the node reached after all of t).
#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
typedef long long LL;

const int maxn = 2e6 + 6;   // total input length is at most 2,000,000
int n, sum_len;
char s[maxn], s_rev[maxn];
int st[maxn], ed[maxn];

int nxt[maxn], ex[maxn];
bool flag[2][maxn];

// "next" array (Z-function) of the segment str[l..r], stored at the same
// absolute indices: nxt[i] = length of the longest common prefix of str[i..r]
// and str[l..r].
void GETNEXT(char *str, int l, int r)
{
    nxt[l] = r - l + 1;            // the segment matches itself completely
    if (l == r) return;            // one character: nothing else to compute

    // nxt[l + 1] by direct comparison. The bound is tested first so the scan
    // never compares against the byte just past the segment (which may belong
    // to another string or hold stale data) and never exceeds r - l.
    int i = l;
    while (i + 1 <= r && str[i] == str[i + 1])
        i++;
    nxt[l + 1] = i - l;

    int po = l + 1;                // start of the match reaching furthest right
    for (i = l + 2; i <= r; i++) {
        if (nxt[i - po + l] + i - l < nxt[po] + po - l) {
            // Case 1: i lies inside the known match and its mirrored value does
            // not reach the match's right end, so it can be copied directly.
            nxt[i] = nxt[i - po + l];
        } else {
            // Case 2: resume comparing at the match's right end, or from
            // scratch if i is already beyond it.
            int j = nxt[po] + po - i;
            if (j < 0) j = 0;
            while (i + j <= r && str[l + j] == str[i + j])
                j++;
            nxt[i] = j;
            po = i;
        }
    }
}

// Extended KMP on segment [l, r]: ex[i] = length of the longest common prefix
// of s1[i..r] and s2[l..r]. Where that prefix covers the rest of the segment
// (ex[i] == r - i + 1), flag[sign][i] is set. With s2 the reverse of s1 this
// marks the positions whose suffix is a palindrome.
void EXKMP(char *s1, char *s2, int l, int r, int sign)
{
    GETNEXT(s2, l, r);

    // ex[l] by direct comparison; bound checked before reading.
    int i = l;
    while (i <= r && s1[i] == s2[i])
        i++;
    ex[l] = i - l;

    int po = l;                    // start of the match reaching furthest right
    for (i = l + 1; i <= r; i++) {
        if (nxt[i - po + l] + i - l < ex[po] + po - l) {
            // Case 1: the value is fully determined by the next array.
            ex[i] = nxt[i - po + l];
        } else {
            // Case 2: continue matching from the known right end (or from 0).
            int j = ex[po] + po - i;
            if (j < 0) j = 0;
            while (i + j <= r && s1[i + j] == s2[l + j])
                j++;
            ex[i] = j;
            po = i;
        }
    }

    for (i = l; i <= r; i++) {
        if (ex[i] == r - i + 1)
            flag[sign][i] = true;
    }
}

// Lowercase-letter trie over all input strings; node 0 is the root.
// Note: trie[][] alone is 26 * maxn ints (about 208 MB), close to the limit.
struct Trie{
    int trie[maxn][26];
    int cnt[maxn];   // number of strings ending exactly at this node
    int val[maxn];   // strings passing through this node whose rest is a palindrome
    int tot;         // index of the last allocated node
}tr;

// Reset all per-test-case state: palindrome flags, the scratch arrays, and
// every trie node allocated by the previous test case.
void init()
{
    memset(flag, 0, sizeof(flag));
    memset(nxt, 0, sizeof(nxt));
    memset(ex, 0, sizeof(ex));
    for (int i = 0; i <= tr.tot; i++) {
        memset(tr.trie[i], 0, sizeof(tr.trie[i]));
        tr.cnt[i] = 0;
        tr.val[i] = 0;
    }
    tr.tot = 0;
}

// Insert str[l..r] into the trie. Before consuming character k, the current
// node's val is incremented when the suffix str[k..r] is a palindrome
// (flag[0][k]), so val only counts strings strictly longer than the node's
// prefix.
void insertt(char *str, int l, int r)
{
    int p = 0;
    for (int k = l; k <= r; k++) {
        int ch = str[k] - 'a';
        tr.val[p] += flag[0][k];
        if (tr.trie[p][ch] == 0)
            tr.trie[p][ch] = ++tr.tot;
        p = tr.trie[p][ch];
    }
    tr.cnt[p]++;
}

// For t = str[l..r], the reverse of some input string A, return how many
// inserted strings B make B + A a palindrome.
LL searchh(char *str, int l, int r)
{
    LL ans = 0;
    int p = 0;
    for (int k = l; k <= r; k++) {
        int ch = str[k] - 'a';
        p = tr.trie[p][ch];
        if (p == 0) break;                 // no inserted string has this prefix
        // B == t[l..k]: valid when the rest of t is empty or a palindrome.
        if (k == r || flag[1][k + 1])
            ans += tr.cnt[p];
    }
    // B strictly longer than t with t as a prefix and a palindromic remainder.
    if (p) ans += tr.val[p];
    return ans;
}

int main()
{
    // Process test cases until input runs out or is malformed (scanf != 1),
    // which also avoids spinning forever on unreadable input.
    while (scanf("%d", &n) == 1) {
        sum_len = 0;
        init();
        for (int i = 0; i < n; i++) {
            int l;
            if (scanf("%d %s", &l, s + sum_len) != 2) {
                n = i;                     // only count strings actually read
                break;
            }
            // Trust the characters actually read over the declared length.
            l = (int)strlen(s + sum_len);
            for (int j = 0; j < l; j++)
                s_rev[sum_len + j] = s[sum_len + l - 1 - j];
            st[i] = sum_len;
            sum_len += l;
            ed[i] = sum_len - 1;

            EXKMP(s, s_rev, st[i], ed[i], 0);   // palindromic suffixes of s
            EXKMP(s_rev, s, st[i], ed[i], 1);   // palindromic suffixes of s_rev
            insertt(s, st[i], ed[i]);
        }

        LL ans = 0;
        for (int i = 0; i < n; i++)
            ans += searchh(s_rev, st[i], ed[i]);
        printf("%lld\n", ans);
    }
    return 0;
}