F. Birthday Cake

2021 Shandong Provincial Collegiate Programming Contest

传送门 (problem link; the original hyperlink was not preserved)

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// One 64-bit unsigned word per hash component.
using ull = unsigned long long;
// Double hash: a pair of components, each kept modulo its own prime.
using Hash = std::pair<ull, ull>;

// Debug hook. When the LOCAL macro is defined (e.g. built with -DLOCAL), a
// local-only header is included instead; presumably it provides a real
// debug(...) macro (not shown here, confirm locally). Otherwise debug(...)
// expands to the inert expression 42, so any debug statements have no effect.
#ifdef LOCAL
#include <debugger>
#else
#define debug(...) 42
#endif

// Global 64-bit Mersenne Twister, seeded from the current steady_clock tick
// count so that each run draws different random values.
mt19937_64 rng([] {
  const auto ticks = chrono::steady_clock::now().time_since_epoch().count();
  return static_cast<mt19937_64::result_type>(ticks);
}());

// Replaces x with y unless x >= y already holds (keeps the larger value).
template <typename T> void chkmax(T &x, T y) {
  if (!(x >= y)) x = y;
}
// Replaces x with y unless x <= y already holds (keeps the smaller value).
template <typename T> void chkmin(T &x, T y) {
  if (!(x <= y)) x = y;
}
// Length bound (4e5 + 10); presumably the maximum total input length from the
// problem limits — confirm against the statement.
constexpr int N = 400010;

// Two prime moduli for double hashing: .first uses 1e9+9, .second uses 1e9+7.
constexpr Hash mod{1000000009ULL, 1000000007ULL};

// Component-wise product of two double hashes, each reduced by its modulus.
// The 64-bit products stay below 2^64 only when both operands are already
// reduced (below ~1e9).
Hash operator*(const Hash &lhs, const Hash &rhs) {
  const ull p1 = lhs.first * rhs.first;
  const ull p2 = lhs.second * rhs.second;
  return Hash(p1 % mod.first, p2 % mod.second);
}

// Component-wise sum of two double hashes, reduced by each modulus.
Hash operator+(const Hash &lhs, const Hash &rhs) {
  Hash sum = lhs;
  sum.first = (sum.first + rhs.first) % mod.first;
  sum.second = (sum.second + rhs.second) % mod.second;
  return sum;
}

// Component-wise difference (a - b) of two double hashes, result in [0, mod).
// Both operands are reduced first, so the unsigned subtraction never wraps.
Hash operator-(const Hash &a, const Hash &b) {
  const ull a1 = a.first % mod.first, b1 = b.first % mod.first;
  const ull a2 = a.second % mod.second, b2 = b.second % mod.second;
  return {(a1 + mod.first - b1) % mod.first,
          (a2 + mod.second - b2) % mod.second};
}

// Multiplies both components of a double hash by the scalar b, reducing each
// by its modulus. b is reduced first so that a large scalar (e.g. a negative
// character code wrapped to ull) cannot overflow the 64-bit product; a is
// expected to be already reduced (below its modulus).
Hash operator*(const Hash &a, const ull &b) {
  return {a.first * (b % mod.first) % mod.first,
          a.second * (b % mod.second) % mod.second};
}

// Forward declaration; solve() is defined later in this file.
void solve();

// Switches iostreams to unsynchronized, untied mode for fast I/O, then runs
// solve() once per test case. The input holds a single test case; read the
// count from input instead for multi-test judges.
int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);

  int tests = 1;  // std::cin >> tests;
  for (int t = 0; t < tests; ++t) {
    solve();
  }
  return 0;
}

void solve() {
  Hash base = {rng() % mod.first, rng() % mod.second};
  
  int n; cin >> n; 

  ll ans = 0;
  map<Hash, int> mp1, mp2; //mp1是完全一样的字符串个数 mp2是一个作为另一个的 子串,且另一个前缀后缀相同
  
  vector<Hash> hash(N + 1);  hash[0] = {1, 1};
  for(int i = 1; i <= N; i ++ ) hash[i] = hash[i - 1] * base;
  
  for(int i = 0; i < n; i ++ ) {
    string s; cin >> s;
    int m = s.size();
    vector<Hash> hashs(m + 1);
    hashs[0] = {1, 1};

    for(int i = 1; i <= m; i ++ ) { 
      hashs[i] = hashs[i - 1] * base + hashs[0] * (s[i - 1] - 'a' + 1);
    }

    Hash S = hashs[m] - hashs[0] * hash[m]; //整个字符串的哈希值要减去 hashs[0] * hash[m] /jk /jk
    
    ans += mp1[S]; //加上和当前串完全一样的字符串个数
    mp1[S] ++;  
    ans += mp2[S]; //子串
    for(int i = 1; i * 2 < m; i ++ ) {
      Hash x = hashs[i] - (hashs[0] * hash[i]);
      Hash y = hashs[m] - (hashs[m - i] * hash[i]);
      if(x == y) {
        Hash mid = hashs[m - i] - (hashs[i] * hash[m - i - i]);
        ans += mp1[mid];
        mp2[mid] ++;
      } 
    }
  }
  
  cout << ans;

}
posted @ 2022-05-10 11:11  ccz9729  阅读(54)  评论(0)  编辑  收藏  举报