manacher

manacher

  • 用途 : 找字符串中的最长的回文子串。

考虑该问题【模板】manacher 求最长回文串长度。该如何做?

  • 暴力 \(O(n^2)\)

    就是枚举回文中心,向外拓展。代码太简单了,就不挂了。其实是懒得打

  • 二分+hash \(O(n\log n)\)

    将字符串正向hash,反向hash,枚举回文中心,二分答案即可。

    足以通过本题了

    这代码跑的比一些实现较劣的正解快

    点此查看代码
    #include<bits/stdc++.h>
    #include<bits/extc++.h>
    // using namespace __gnu_pbds;
    // using namespace __gnu_cxx;
    using namespace std;
    #define infile(x) freopen(x,"r",stdin)
    #define outfile(x) freopen(x,"w",stdout)
    #define errfile(x) freopen(x,"w",stderr)
    using ll=long long;using ull=unsigned long long;
    #ifdef LOCAL
        FILE *InFile = infile("in.in"),*OutFile = outfile("out.out");
        // FILE *ErrFile=errfile("err.err");
    #else
        FILE *Infile = stdin,*OutFile = stdout;
        //FILE *ErrFile = stderr;
    #endif
    const int N = 1.1e7 + 10,base = 131;
    char s[N];
    ull pw[N],has1[N],has2[N];
    int n,ans = 0;
    inline ull get_hash1(int l,int r){return has1[r] - has1[l-1]*pw[r-l+1];}
    inline ull get_hash2(int l,int r){return has2[l] - has2[r+1]*pw[r-l+1];}
    inline bool check1(int mid,int pos){
        return get_hash1(pos-mid,pos-1) == get_hash2(pos,pos+mid-1);
    }
    inline bool check2(int mid,int pos){
        return get_hash1(pos-mid,pos-1) == get_hash2(pos+1,pos+mid);
    }
    inline void binary1(int i){
        int r = min(i-1,n-i),l = max(ans/2,1),res = 0;
        while(l <= r){
            int mid = (l+r)>>1;
            if(get_hash1(i-mid,i-1) == get_hash2(i+1,i+mid)) res = mid,l = mid+1;
            else r = mid-1;
        }
        ans = max(ans,res * 2 + 1);
    }
    inline void binary2(int i){
        int r = min(i-1,n-i) + 1,l = max(ans/2,1),res = 0;
        while(l <= r){
            int mid = (l+r)>>1;
            if(get_hash1(i-mid,i-1) == get_hash2(i,i+mid-1)) res = mid,l = mid+1;
            else r = mid-1;
        }
        ans = max(ans,res*2);
    }
    signed main(){
        cin.tie(nullptr)->sync_with_stdio(false);
        cout.tie(nullptr)->sync_with_stdio(false);
        cin >> (s+1);
        n = strlen(s+1);
        pw[0] = 1;
        ans = 1;
        for(int i = 1;i <= n; ++i) pw[i] = pw[i-1]*base;
        for(int i = 1;i <= n; ++i) has1[i] = has1[i-1]*base+s[i];
        for(int i = n;i >= 1; --i) has2[i] = has2[i+1]*base+s[i];
        for(int i = 1;i <= n; ++i){
            if(s[i-1] == s[i+1]) binary1(i);
            if(s[i] == s[i-1]) binary2(i);
        }
        cout<<ans;
    }
    
  • hash \(O(n)\) 我不会。挂个oi-wiki的讲解

我们发现,以上这些方法,要么时间复杂度超标,要么代码难度和细节上超标。

那么有没有一种方法可以既不会T,也非常好写的算法呢?

  • manacher \(O(n)\)

    优点 : 常数小,代码简洁。

    • 实现方法 : 考虑优化最上面的内个暴力。

      枚举回文中心有两种情况:

      1. 形如\(aba\)的长度为奇数的回文串,回文中心是它最中间的字符。
      2. 形如\(abba\)的长度为偶数的回文串,回文中心是两个字符最中间的字符中间的空隙

      这样就要分类讨论,非常不符合manacher简洁的特点。

      所以manacher对其的优化就是在相邻两个字符中间和字符串首尾(为了方便处理)插入一个不会被用到的字符(通常是$,#,@等,这里我选择使用#)

      那么形如\(aba\)的回文串就会变成 #a#b#a#,没有影响。

      形如\(abba\)的回文串就会变成 #a#b#b#a#,我们枚举回文中心至第二个#时就会将\(abba\)计入

      注意我们要在字符串最前面和字符串最后面插入两个不同的字符,使其变成形如 $#a#b#n#j#k#@的形式,防止越界

      我们发现,暴力较慢的原因是因为它重复遍历了子串。

      如何将这一部分优化掉

      我们记录一个变量r为已经触及到的最右边的字符的位置。

      再记录一个变量mid表示回文串中包含r的回文串的回文中心的位置。

      image

      用一个数组\(len_i\)表示第i个点能够扩展出的回文串长度。

      \(mid<i<r\)时,设\(i\)关于\(mid\)的对称点为\(j\),我们可以证明一定有\([j-len_j,j+len_j] = [i-len_i,i+len_i]\)

      注意当\(i+len_i>r\)时,无法保证上述情况相等。

      对于该情况,我们直接令\(len_i=\min(len_j,r-i+1)\),暴力扩展。

      反之,若\(i>r\),此时只能暴力扩展。

      点此查看代码
      #include<bits/stdc++.h>
      #include<bits/extc++.h>
      // using namespace __gnu_pbds;
      // using namespace __gnu_cxx;
      using namespace std;
      using ll=long long;using ull=unsigned long long;
      using db = double;using ldb = long double;
      #define infile(x) freopen(x,"r",stdin)
      #define outfile(x) freopen(x,"w",stdout)
      #define errfile(x) freopen(x,"w",stderr)
      #ifdef LOCAL
          FILE *InFile = infile("in.in"),*OutFile = outfile("out.out");
          // FILE *ErrFile=errfile("err.err");
      #else
          FILE *Infile = stdin,*OutFile = stdout;
          //FILE *ErrFile = stderr;
      #endif
      const int N = 1.1e7 + 100;
      int n,tot,len[N<<1],mid,mxr,ans = 0;//注意len数组要开二倍空间
      char s1[N],s[N<<1];
      signed main(){
          cin.tie(nullptr)->sync_with_stdio(false);
          cout.tie(nullptr)->sync_with_stdio(false);
          cin>>(s1+1);
          n = strlen(s1+1);
          s[++tot] = '$';s[++tot] = '#';
          for(int i = 1;i <= n; ++i) s[++tot] = s1[i],s[++tot] = '#';
          s[++tot] = '@';s[++tot] = '\0';
          n = strlen(s+1);
          for(int i = 2;i < n; ++i){
              if(i < mxr) len[i] = min(len[mid*2-i],mxr-i+1);
              else len[i] = 1;
              while(s[i - len[i]] == s[i + len[i]]) ++len[i];
              if(i + len[i] - 1 > mxr) mxr = i + len[i] - 1,mid = i;
              ans = max(ans,len[i]);
          }
          cout<<ans-1;//千万注意这个
      }
      

      复杂度的话,感性理解一下吧,就是mid和r都单调不降,所以复杂度为线性。

例题 :

  1. [国家集训队] 拉拉队排练

    用桶存一下即可,其它板子。

    点此查看代码
    #include<bits/stdc++.h>
    #include<bits/extc++.h>
    // using namespace __gnu_pbds;
    // using namespace __gnu_cxx;
    using namespace std;
    using ll=long long;using ull=unsigned long long;
    using db = double;using ldb = long double;
    #define infile(x) freopen(x,"r",stdin)
    #define outfile(x) freopen(x,"w",stdout)
    #define errfile(x) freopen(x,"w",stderr)
    #ifdef LOCAL
        FILE *InFile = infile("in.in"),*OutFile = outfile("out.out");
        FILE *ErrFile=errfile("err.err");
    #else
        FILE *Infile = stdin,*OutFile = stdout;
        //FILE *ErrFile = stderr;
    #endif
    const int N = 2e6 + 10,mod = 19930726;
    char s1[N],s[N<<1];
    int len[N<<1],n,mid,mxr,tot,num[N<<1];
    ll k;
    inline int power(int a,int b,int mod){
        int res = 1;
        for(; b;b >>= 1,a = 1ll * a * a % mod)
            if(b&1)res = 1ll * res * a % mod;
        return res;
    }
    signed main(){
        cin.tie(nullptr)->sync_with_stdio(false);
        cout.tie(nullptr)->sync_with_stdio(false);
        cin>>n>>k>>(s1+1);
        s[++tot] = '$',s[++tot] = '#';
        for(int i = 1;i <= n; ++i) s[++tot] = s1[i],s[++tot] = '#';
        s[++tot] = '@',s[++tot] = '\0';
        n = strlen(s+1);
        for(int i = 2;i < n; ++i){
            if(i < mxr) len[i] = min(len[mid*2-i],mxr-i+1);
            else len[i] = 1;
            while(s[i - len[i]] == s[i + len[i]]){
                ++len[i];
            }
            if(i + len[i] - 1 > mxr) mxr = i + len[i] - 1,mid = i;
            if((len[i] - 1) & 1) num[len[i]-1]++;
        }
        ll res = 1,sum = 0;
        for(int i = n;i >= 1; --i){
            if(!(i&1)) continue;
            sum += num[i];
            if(sum > k) return cout<<(res = 1ll * res * power(i,k,mod) % mod)<<'\n',0;
            else res = 1ll * res * power(i,sum,mod) % mod,k -= sum;
        }
        cout<<(k?-1:res);
    }
    
  2. 高手过愚人节

    模板题,就是练一下手感。

    点此查看代码
    #include<bits/stdc++.h>
    #include<bits/extc++.h>
    // using namespace __gnu_pbds;
    // using namespace __gnu_cxx;
    using namespace std;
    #define infile(x) freopen(x,"r",stdin)
    #define outfile(x) freopen(x,"w",stdout)
    #define errfile(x) freopen(x,"w",stderr)
    #ifdef LOCAL
        FILE *InFile = infile("in.in"),*OutFile = outfile("out.out");
        // FILE *ErrFile=errfile("err.err");
    #else
        FILE *Infile = stdin,*OutFile = stdout;
        //FILE *ErrFile = stderr;
    #endif
    bool StdIn = cin.tie(nullptr)->sync_with_stdio(false);
    bool StdOut = cout.tie(nullptr)->sync_with_stdio(false);
    using ll=long long;using ull=unsigned long long;
    using db = double;using ldb = long double;
    const int N = 1e7 + 10;
    int n,tot,len[N<<1],mid,mxr,ans = 0;
    char s1[N],s[N<<1];
    inline void manacher(){
        cin>>(s1+1);n = strlen(s1+1);
        tot = mid = mxr = ans = 0;
        s[++tot] = '$';s[++tot] = '#';
        for(int i = 1;i <= n; ++i) s[++tot] = s1[i],s[++tot] = '#';
        s[++tot] = '@';s[++tot] = '\0';
        n = strlen(s+1);
        for(int i = 2;i < n; ++i){
            if(i < mxr) len[i] = min(len[mid*2-i],mxr-i+1);
            else len[i] = 1;
            while(s[i - len[i]] == s[i + len[i]]) ++len[i];
            if(i + len[i] - 1 > mxr) mxr = i + len[i] - 1,mid = i;
            ans = max(ans,len[i] - 1);
        }
        cout<<ans<<'\n';
    }
    signed main(){
        int T;cin>>T;
        while(T--) manacher();
    }
    
  3. [国家集训队] 最长双回文串

    枚举分割点应该是非常容易想到的。

    我们记录一下以该点为左端点的最长回文串长度,以该点为右端点的最长回文串长度。

    但我们发现manacher无法求出每个点这个值。

    0|1|2|3|4|5|6|7|8|9|10|11|12|13|14|15|16|17

    $|#|a|#|b|#|a|#|b|#|a |# |c |# |c |# |d |@

    比如这个,我们发现\(l_3\)求出来是0,但很明显,\(l_3\)应该是3,因为bab为一个回文串。

    因为我们在\(i=6\)时,\(len_i=6\),所以\(l_3\)没有更新,只更新到了\(l_1\),所以我们要将\(l_3\)通过\(l_1\)更新出来

    所以我们再一次递推,将\(l_i\)\(l_{i-2}-2\)取max

    同理,\(right_i\)\(right_{i+2}-2\)取max

    点此查看代码
    #include<bits/stdc++.h>
    #include<bits/extc++.h>
    // using namespace __gnu_pbds;
    // using namespace __gnu_cxx;
    using namespace std;
    #define infile(x) freopen(x,"r",stdin)
    #define outfile(x) freopen(x,"w",stdout)
    #define errfile(x) freopen(x,"w",stderr)
    #ifdef LOCAL
        FILE *InFile = infile("in.in"),*OutFile = outfile("out.out");
        // FILE *ErrFile=errfile("err.err");
    #else
        FILE *Infile = stdin,*OutFile = stdout;
        //FILE *ErrFile = stderr;
    #endif
    bool StdIn = cin.tie(nullptr)->sync_with_stdio(false);
    bool StdOut = cout.tie(nullptr)->sync_with_stdio(false);
    using ll=long long;using ull=unsigned long long;
    using db = double;using ldb = long double;
    const int N = 2e5 + 10;
    int lft[N],rht[N],len[N],n,tot,mid,mxr;
    char s1[N],s[N];
    signed main(){
        cin>>(s1+1);
        n = strlen(s1+1);
        s[++tot] = '$',s[++tot] = '#';
        for(int i = 1;i <= n; ++i) s[++tot] = s1[i],s[++tot] = '#';
        s[++tot] = '@';s[++tot] = '\0';
        n = strlen(s+1);
        for(int i = 2;i < n; ++i){
            if(i < mxr) len[i] = min(len[mid*2-i],mxr-i+1);
            else len[i] = 1;
            while(s[i - len[i]] == s[i + len[i]]) len[i]++;
            if(i + len[i] - 1 > mxr) mxr = i + len[i] - 1,mid = i;
            lft[i + len[i] - 1] = max(len[i] - 1,lft[i + len[i] - 1]),
            rht[i - len[i] + 1] = max(len[i] - 1,rht[i - len[i] + 1]);
        }
        for(int i = 2;i <= n - 1;i += 2) rht[i] = max(rht[i],rht[i + 2] - 2);
        for(int i = n - 1;i >= 2;i -= 2) lft[i] = max(lft[i],lft[i - 2] - 2);
        int ans = 0;
        for(int i = 2;i < n;i += 2) if(lft[i] && rht[i]) ans = max(ans,lft[i] + rht[i]);
        cout<<ans;
    }
    
posted @ 2024-07-30 17:36  CuFeO4  阅读(6)  评论(0编辑  收藏  举报