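// Colorful String: sum, over every occurrence of every palindromic substring, of the number of distinct letters it contains.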
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int maxn = 1e6 + 5;
// init() reads at most 1000000 letters into s + 1; s must also hold s[0] and the terminator.
static_assert(maxn >= 1000000 + 2, "s is too small for the scanf width used in PAM::init");

char s[maxn];       // input string, 1-indexed; s[0] = 0 is a sentinel for PAM::getfail
int n;              // length of the input string
int record[maxn];   // record[v]: end position in s of one occurrence of PAM node v's palindrome
int sum[maxn * 4];  // segment tree over s[1..n]: bitmask of letters present in each node's range

// Recompute node rt's letter mask from its two children.
void pushup(int rt) {
    sum[rt] = sum[rt * 2] | sum[rt * 2 + 1];
}

// Build the letter-mask segment tree for s[l..r] rooted at rt.
// Assignment (not |=) makes rebuilding for a new string correct.
// An empty range (l > r, e.g. n == 0) is a no-op instead of recursing forever.
void build(int l, int r, int rt) {
    if (l > r) return;
    if (l == r) {
        sum[rt] = 1 << (s[l] - 'a');  // assumes lowercase 'a'..'z' input
        return;
    }
    int mid = (l + r) / 2;
    build(l, mid, rt * 2);
    build(mid + 1, r, rt * 2 + 1);
    pushup(rt);
}

// Return the OR of letter masks over s[be..ed]; [l, r] is the range covered by node rt.
int query(int be, int ed, int l, int r, int rt) {
    if (be <= l && r <= ed) return sum[rt];
    int mid = (l + r) / 2, res = 0;
    if (be <= mid) res |= query(be, ed, l, mid, rt * 2);
    if (ed > mid) res |= query(be, ed, mid + 1, r, rt * 2 + 1);
    return res;
}

// Palindromic tree (eertree) over the global string s[1..n].
// Node 0 is the even root (len 0, fail -> 1); node 1 is the odd root (len -1).
struct PAM {
    int last;  // node of the longest palindromic suffix of the processed prefix

    struct Node {
        ll cnt;          // before count(): times this node was `last`; after: total occurrences
        int len;         // palindrome length (-1 for the odd root)
        int fail;        // node of the longest proper palindromic suffix
        int son[26]{};   // son[c]: node of the palindrome c + this + c (0 = absent)
        Node(int len, int fail) : cnt(0), len(len), fail(fail) {}
    };
    vector<Node> st;

    // Append a node and return its index.
    int newnode(int len, int fail = 0) {
        st.emplace_back(len, fail);
        return (int)st.size() - 1;
    }

    // Follow fail links from x until the palindrome of x can be extended by s[pos] on both sides.
    int getfail(int x, int pos) {
        while (s[pos - st[x].len - 1] != s[pos]) x = st[x].fail;
        return x;
    }

    // Add letter c (= s[i] - 'a') at position i.
    void extend(int c, int i) {
        int cur = getfail(last, i);
        if (!st[cur].son[c]) {
            // The fail link is read before st[cur].son[c] is set, so a new length-1
            // palindrome (cur == odd root) gets fail 0, the even root.
            int fail = st[getfail(st[cur].fail, i)].son[c];
            int nw = newnode(st[cur].len + 2, fail);
            st[cur].son[c] = nw;
        }
        last = st[cur].son[c];
        st[last].cnt++;
        record[last] = i;
    }

    // (Re)build the tree for the string currently stored in s + 1, and set n.
    // Clears any previous tree; at most n + 2 nodes exist, so one reservation suffices.
    void rebuild() {
        n = (int)strlen(s + 1);
        s[0] = 0;
        st.clear();
        st.reserve(n + 2);
        newnode(0, 1);
        newnode(-1);
        last = 0;
        for (int i = 1; i <= n; i++) extend(s[i] - 'a', i);
    }

    // Read one whitespace-delimited word (at most 1000000 letters) from stdin and build the tree.
    // A failed read is treated as the empty string.
    void init() {
        if (scanf("%1000000s", s + 1) != 1) s[1] = '\0';
        rebuild();
    }

    // Return the sum, over every occurrence of every palindromic substring of s,
    // of the number of distinct letters it contains. Requires build(1, n, 1) to have been run
    // for the current string. Call once per rebuild: it accumulates occurrence counts in place.
    ll count() {
        // Each fail link points to an earlier node, so a reverse sweep pushes every
        // node's count into all of its palindromic suffixes.
        for (int i = (int)st.size() - 1; i >= 0; i--)
            st[st[i].fail].cnt += st[i].cnt;

        ll ans = n;  // the n single-letter occurrences each contribute exactly 1
        for (int i = 2; i < (int)st.size(); i++) {
            if (st[i].len <= 1) continue;  // length-1 palindromes already counted above
            int R = record[i], L = R - st[i].len + 1;
            int num = __builtin_popcount((unsigned)query(L, R, 1, n, 1));
            ans += st[i].cnt * num;
        }
        return ans;
    }
} pam;
int main() { 89 pam.init(); 90 build(1,n,1); 91 printf("%lld\n",pam.count()); 92 return 0; 93 }