poj3376 Finding Palindromes【exKMP】【Trie】

Finding Palindromes
Time Limit: 10000MS   Memory Limit: 262144K
Total Submissions:4710   Accepted: 879
Case Time Limit: 2000MS

Description

A word is called a palindrome if we read from right to left is as same as we read from left to right. For example, "dad", "eye" and "racecar" are all palindromes, but "odd", "see" and "orange" are not palindromes.

Given n strings, you can generate n × n pairs of them and concatenate the pairs into single words. The task is to count how many of the so generated words are palindromes.

Input

The first line of input file contains the number of strings n. The following n lines describe each string:

The i+1-th line contains the length of the i-th string li, then a single space and a string of li small letters of English alphabet.

You can assume that the total length of all strings will not exceed 2,000,000. Two strings in different line may be the same.

Output

Print out only one integer, the number of palindromes.

Sample Input

3
1 a
2 ab
2 ba

Sample Output

5

Hint

The 5 palindromes are: 
aa aba aba abba baab 

Source

POJ Monthly--2007.09.09, Zhou Gelin, modified from POI06
 
题意:
给定n个字符串。问把他们两两组合起来,能有多少个回文串。
思路:
两个字符串$s,t$进行组合有三种情况。
1、$lens=lent$,若$s$和$t$对称,那么可以$st$是一个回文串。
2、$lens>lent$,若$t$和$s$的前缀对称,$s$剩余的部分是一个回文,那么$st$是一个回文串。
3、$lens<lent$,若$s$和$t$的前缀对称,$t$剩余的部分是一个回文,那么$ts$是一个回文串。
现在我们对所有字符串建立一个字典树,然后用每一个反向串在树上进行匹配。
对于第一种可能,很简单,反向串匹配完的那个节点,是$cnt$个字符串的结尾,$ans+=cnt$
对于第二种可能,反向串不能继续匹配了之后,如果剩余的串是是回文,那么也是$ans += cnt$
对于第三种可能,反向串匹配到结尾之后,如果树上这个节点之后有$val$个回文,那么$ans+=val$
因此我们现在需要在Trie树上维护两个值,$cnt$和$val$
 
exKMP我们可以求出某个点之后是不是回文串。
将$s$作为模式串,与反向串$s'$进行exkmp。若$ex[i] = len - i + 1$,那么$s'[i...len]$是一个回文串。
同理将$s'$作为模式串,与$s$进行exkmp。
由此可以在建立Trie树时确定$cnt$和$val$
 
由于题目没有说$n$的范围,只给定了字符串的总长。所以我们只能把所有字符串都存在同一个字符串中,用数组记录开始和结束下标。
 
  1 #include<iostream>
  2 //#include<bits/stdc++.h>
  3 #include<cstdio>
  4 #include<cmath>
  5 //#include<cstdlib>
  6 #include<cstring>
  7 #include<algorithm>
  8 //#include<queue>
  9 #include<vector>
 10 //#include<set>
 11 //#include<climits>
 12 //#include<map>
 13 using namespace std;
 14 typedef long long LL;
 15 typedef unsigned long long ull;
 16 #define N 100010
 17 #define pi 3.1415926535
 18 #define inf 0x3f3f3f3f
 19 
 20 const int maxn = 2e6 + 6;
 21 int n, sum_len;
 22 char s[maxn], s_rev[maxn];
 23 int st[maxn], ed[maxn];
 24 
 25 int nxt[maxn], ex[maxn];
 26 bool flag[2][maxn];
 27 void GETNEXT(char *str, int l, int r)
 28 {
 29     int i=l,j,po;
 30     nxt[i]=r - l + 1;//初始化next[0]
 31     while(str[i]==str[i+1]&&i<=r)//计算next[1]
 32     i++;
 33     nxt[l + 1]=i - l;
 34     po=1 + l;//初始化po的位置
 35     for(i=2+l;i<=r;i++)
 36     {
 37         if(nxt[i-po + l]+i - l<nxt[po]+po - l)//第一种情况,可以直接得到next[i]的值
 38         nxt[i]=nxt[i-po + l];
 39         else//第二种情况,要继续匹配才能得到next[i]的值
 40         {
 41             j=nxt[po]+po-i;
 42             if(j<0)j=0;//如果i>po+next[po],则要从头开始匹配
 43             while(i+j<=r&&str[l+j]==str[j+i])//计算next[i]
 44             j++;
 45             nxt[i]=j;
 46             po=i;//更新po的位置
 47         }
 48     }
 49 }
 50 //计算extend数组
 51 void EXKMP(char *s1,char *s2, int l, int r, int sign)
 52 {
 53     int i=l,j,po;
 54     GETNEXT(s2, l, r);//计算子串的next数组
 55     //for(j = l; j <= r; j++)cout<<nxt[j]<<endl;
 56     while(s1[i]==s2[i]&&i<=r)//计算ex[0]
 57     i++;
 58     ex[l]=i - l;
 59     po=l;//初始化po的位置
 60     for(i=l + 1;i<=r;i++)
 61     {
 62         if(nxt[i-po + l]+i - l<ex[po]+po -l)//第一种情况,直接可以得到ex[i]的值
 63         ex[i]=nxt[i-po+l];
 64         else//第二种情况,要继续匹配才能得到ex[i]的值
 65         {
 66             j=ex[po]+po-i;
 67             if(j<0)j=0;//如果i>ex[po]+po则要从头开始匹配
 68             while(i+j<=r&&s1[j+i]==s2[l+j])//计算ex[i]
 69             j++;
 70             ex[i]=j;
 71             po=i;//更新po的位置
 72         }
 73     }
 74     for(int i = l; i <= r; i++){
 75         //cout<<ex[i]<<endl;
 76         if(ex[i] == r - i + 1){
 77             flag[sign][i] = true;
 78         }
 79     }
 80 }
 81 
 82 struct Trie{
 83     int trie[maxn][26];
 84     int cnt[maxn];
 85     int val[maxn];
 86     int tot;
 87 }tr;
 88 
 89 void init()
 90 {
 91     memset(flag, 0, sizeof(flag));
 92     memset(nxt, 0, sizeof(nxt));
 93     memset(ex, 0, sizeof(ex));
 94     for(int i = 0; i <= tr.tot; i++){
 95         memset(tr.trie[i], 0, sizeof(tr.trie[i]));
 96         tr.cnt[i] = 0;
 97         tr.val[i] = 0;
 98     }
 99     tr.tot = 0;
100 }
101 
102 void insertt(char *str, int l, int r)
103 {
104     int p = 0;
105     for(int k = l; k <= r; k++){
106         int ch = str[k] - 'a';
107         tr.val[p] += flag[0][k];
108         if(tr.trie[p][ch] == 0){
109             tr.trie[p][ch] = ++tr.tot;
110         }
111         p = tr.trie[p][ch];
112 
113     }
114     tr.cnt[p]++;
115 }
116 
117 LL searchh(char *str, int l, int r)
118 {
119     LL ans = 0;
120     int p = 0;
121     for(int k = l; k <= r; k++){
122         //cout<<flag[k]<<endl;
123         int ch = str[k] - 'a';
124         p = tr.trie[p][ch];
125         if(p == 0)break;
126         if(k < r && flag[1][k + 1] || k == r)ans += tr.cnt[p];
127     }
128     if(p)ans += tr.val[p];
129     //ans += tr.cnt[p];
130     return ans;
131 }
132 
133 int main()
134 {
135     while(scanf("%d", &n) != EOF){
136         sum_len = 0;
137         init();
138         for(int i = 0; i < n; i++){
139             int l;
140             scanf("%d %s", &l, s + sum_len);
141             for(int j = 0; j < l; j++){
142                 s_rev[sum_len + j] = s[sum_len + l - 1 - j];
143             }
144             st[i] = sum_len;
145             sum_len += l;
146             ed[i] = sum_len - 1;
147 
148             EXKMP(s, s_rev, st[i], ed[i], 0);
149             EXKMP(s_rev, s, st[i], ed[i], 1);
150             insertt(s, st[i], ed[i]);
151         }
152 
153         LL ans = 0;
154         /*for(int i = 0; i < sum_len; i++){
155             cout<<s[i]<<" "<<flag[0][i]<<" "<<flag[1][i]<<endl;
156         }*/
157         //cout<<s<<endl;
158         for(int i = 0; i < n; i++){
159             ans += searchh(s_rev, st[i], ed[i]);
160         }
161         printf("%lld\n", ans);
162     }
163     return 0;
164 }

 

posted @ 2018-11-21 13:46  wyboooo  阅读(316)  评论(0编辑  收藏  举报