牛客14894 最长回文

题面

 

题目链接

 

https://ac.nowcoder.com/acm/problem/14894

 

题目大意

 

有两个长度均为n的字符串A和B。

可以从A中选一个可以为空的子串A[l1..r1],B中选一个可以为空的子串B[l2..r2]

需要满足r1 = l2,然后把它们拼起来(A[l1..r1]+B[l2..r2])

求用这样的方法能得到的最长回文串 S 的长度。

 

解题思路

 

首先对 A , B 串都跑一边 Manacher,分别得到 PA , PB 和处理过的 A , B

然后再枚举处理过的 A / B,以每个字符作为 S 的中心点,半径为 max(PA , PB) 进行拓展即可

因为以 i 为中心 PA / PB 为半径的回文串中,它的最小回文长度就达到了 PA / PB - 1

所以只要从 PA / PB 两端拓展看还能找到多少可以相匹配的字符就可以了

而假设 S 的中心点为 i , A 提供的是 i - 1 , i - 2 ... ,B 提供的是 i + 1 , i + 2 ...,他们之间相差了 2

所以当枚举到 A 的第 i 个字符时,需要操作的是 PA[ i ] 和 PB[ i - 2 ]

 

 

而拓展的方法有两种,一种是逐一匹配,俗称 brute force

另一种是 二分拓展的长度 + hash check  来匹配

第一种做法的复杂度我不太会算,感觉会超时但才跑了500ms?

而第二种做法显然要快上不少,大概是 50ms

这里提供两种做法

 

AC_Coder_(暴力)

#include<bits/stdc++.h>
using namespace std;
const int N = 3e5 + 10;
string a , b;
int pa[N] , pb[N] , res = 1;
string Manacher(string a , int *p)
{
    string t = "$#";
    for(auto i : a) t += i , t += '#';
    int mx = 0 , id = 0 ;
    int len = t.size() , ans = 0; 
    for(int i = 1 ; i < len ; i ++)
    {
        p[i] = mx > i ? min(p[2 * id - i] , mx - i) : 1;
        while(t[i + p[i]] == t[i - p[i]]) p[i] ++ ;
        if(mx < i + p[i]) mx = i + p[i] , id = i;
        ans = max(ans , p[i] - 1);
    }
    res = max(res , ans);
    return t;
}
signed main()
{
    int n ;
    cin >> n >> a >> b;
    a = Manacher(a , pa) , b = Manacher(b , pb);
    n = n * 2 + 2;
    int ans = 1;
    for(int i = 2 ; i <= n ; i ++)
    {
        int len = max(pa[i] , pb[i - 2]);
        while(a[i - len] == b[i - 2 + len]) len ++; 
        ans = max(ans , len - 1); 
    }
    cout << ans << '\n';
    return 0;
}

 

AC_Coder_(hash + 二分)

#include<bits/stdc++.h>
#define int long long
#define ull unsigned long long
using namespace std;
const int N = 3e5 + 10;
const int MOD = 999998639; 
const int P = 13331;
string a , b;
int pa[N] , pb[N] , res = 1;
ull pre[N] , suf[N] , power[N];
string Manacher(string a , int *p)
{
    string t = "$#";
    for(auto i : a) t += i , t += '#';
    int mx = 0 , id = 0 ;
    int len = t.size() , ans = 0; 
    for(int i = 1 ; i < len ; i ++)
    {
        p[i] = mx > i ? min(p[2 * id - i] , mx - i) : 1;
        while(t[i + p[i]] == t[i - p[i]]) p[i] ++ ;
        if(mx < i + p[i]) mx = i + p[i] , id = i;
        ans = max(ans , p[i] - 1);
    }
    res = max(res , ans);
    return t;
}
ull get_hash1(int l , int r)
{
    if(l > r) return -999;
    return (pre[r] - pre[l - 1] * power[r - l + 1] % MOD + MOD) % MOD;
}
ull get_hash2(int l , int r)
{
    if(l > r) return -888;
    return (suf[l] - suf[r + 1] * power[r - l + 1] % MOD + MOD) % MOD;
}
void init(int n)
{
    power[0] = 1;
    for(int i = 1 ; i < N - 5 ; i ++) power[i] = power[i - 1] * P % MOD;
    pre[0] = a[0] ;
    for(int i = 1 ; i < n ; i ++) pre[i] = (pre[i - 1] * P + a[i]) % MOD ;
    suf[n - 1] = b[n - 1];
    for(int i = n - 2 ; i >= 0 ; i --) suf[i] = (suf[i + 1] * P + b[i]) % MOD; 
}
signed main()
{
    int n ;
    cin >> n >> a >> b;
    a = Manacher(a , pa) , b = Manacher(b , pb);
    n = n * 2 + 2;
    init(n);
    int ans = 1;
    for(int i = 2 ; i <= n ; i ++)
    {
        int add = 0;
        int len = max(pa[i] , pb[i - 2]);
        int sa = i - len + 1  , sb = i - 2 + len - 1;
        int l = 0 , r = min(sa - 1 , n - sb);
        while(l <= r)
        {
            int mid = l + r >> 1;
            int l1 = max(0LL , sa - mid) , r1 = min(n , sa - 1);
            int l2 = max(sb + 1 , 0LL) , r2 = min(n , sb + mid);
            if(get_hash1(l1 , r1) == get_hash2(l2 , r2)) 
            l = mid + 1 , add = mid;
            else r = mid - 1; 
        }
        ans = max(ans , add + len - 1);
    }
    cout << ans << '\n';
    return 0;
}

 

posted @ 2020-04-30 17:00  GsjzTle  阅读(345)  评论(1编辑  收藏  举报