以字符集大小为位数的字符串哈希——上海网络赛 G

先预处理一个哈希表 hash[a][b]：对开头字符为 a、结尾字符为 b 的串，以中间部分字符的哈希值 hs 为键，保存该键出现的次数。

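对于一个子串求哈希值的策略：把哈希值看作一个 26 位的数（第 c 位对应第 c 个字母），每新增一个字符，就在对应位上加 1（实现中以 base = 13331 为基数、对 1e9 + 9 取模，即加入字符 c 时加上 base^c）。这样哈希值只取决于每个字母的出现次数，与字符顺序无关，正好刻画"中间部分字符组成相同"。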

#include<bits/stdc++.h>
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/hash_policy.hpp>
using namespace __gnu_pbds;
using namespace std;

// 64-bit type used for intermediate modular products.
typedef long long ll;
// Size of the global string buffers and per-query / per-length arrays.
const int MAXN = 100000 + 10;
// Hash radix: letter number c contributes base^c.
const int base = 13331;
// Prime modulus, 1e9 + 9.
const ll mod = 1000000009LL;

// Modular addition. Expects a and b already reduced to [0, mod); a single
// conditional subtraction then brings the sum back into range.
int add(int a, int b){
    int sum = a + b;
    if(sum >= mod)
        sum -= mod;
    return sum;
}

// Modular subtraction. Expects a and b already reduced to [0, mod); a
// negative difference is wrapped back into range by adding mod once.
int sub(int a, int b){
    int diff = a - b;
    if(diff < 0)
        diff += mod;
    return diff;
}

// Modular multiplication. a is taken as ll so the product is formed in
// 64 bits; the % is skipped when the product is already below mod.
int mul(ll a, int b){
    const ll product = a * b;
    if(product < mod)
        return product;
    return product % mod;
}

// Binary exponentiation: returns a^b (mod mod).
// Expects a in [0, mod) and b >= 0 (a negative b never reaches 0 under >>=).
int qpow(int a, int b){
    int result = 1;
    while(b != 0){
        if(b & 1)
            result = mul(result, a);
        a = mul(a, a);
        b >>= 1;
    }
    return result;
}

// mp[a][b]: hash key -> counter, for strings whose first letter is a and
// whose last letter is b (letters numbered 0-based from 'a').
gp_hash_table<int, int> mp[26][26];

int n;           // strlen(s)
int m;           // strlen(t) for the query currently being read
int cnt[MAXN];   // cnt[len]: number of queries of the current test with length len

char s[MAXN];    // the text
char t[MAXN];    // buffer for one query string

// One query as recorded by main.
struct Node{
    int hs;      // the query's key inside mp[st][ed]
    int len;     // query length
    int st;      // first letter (0-based)
    int ed;      // last letter (0-based)
};
Node nd[MAXN];

int main(){
    int T, q;
    scanf("%d", &T);
    while(T --){
        for(int i = 0; i < 26; i ++) for(int j = 0; j < 26; j ++)
            mp[i][j].clear();
        scanf("%s", s);
        n = strlen(s);
        scanf("%d", &q);
        for(int i = 1; i <= q; i ++){
            scanf("%s", t);
            m = strlen(t);
            cnt[nd[i].len = m] ++;
            nd[i].st = t[0] - 'a';
            nd[i].ed = t[m - 1] - 'a';
            int tmp = 0;
            for(int j = 1; j < m - 1; j ++)
                tmp = add(tmp, qpow(base, t[j] - 'a'));
            mp[nd[i].st][nd[i].ed][nd[i].hs = tmp] = 1;
        }
        for(int i = 2; i <= n; i ++) if(cnt[i]){
            int tmp = 0;
            int st = s[0] - 'a';
            int ed = s[i - 1] - 'a';
            for(int j = 1; j < i - 1; j ++)
                tmp = add(tmp, qpow(base, s[j] - 'a'));
            if(mp[st][ed].find(tmp) != mp[st][ed].end())
                mp[st][ed][tmp] ++;
            for(int l = 1, r = i; r < n; l ++, r ++){
                st = s[l] - 'a';
                ed = s[r] - 'a';
                tmp = sub(tmp, qpow(base, s[l] - 'a'));
                tmp = add(tmp, qpow(base, s[r - 1] - 'a'));
                if(mp[st][ed].find(tmp) != mp[st][ed].end())
                    mp[st][ed][tmp] ++;
            }
        }
        for(int i = 1; i <= q; i ++){
            cnt[nd[i].len] --;
            printf("%d\n", mp[nd[i].st][nd[i].ed][nd[i].hs] - 1);
        }
    }
    return 0;
}

 

posted on 2019-10-18 17:27  zsben  阅读(184)  评论(0)  编辑  收藏  举报

导航