字符串

### Description

  数据范围:\(n<=6,|s_i|<=100,m<=500\)

  

Solution

​  场上不会在ac自动机上面跑dp的我大概失去了智商==果然字符串这块还是有点薄弱啊

  首先想一个比较好的计算方式:比较明白的一点是如果前\(m\)位确定了,后\(m\)位自然也就确定了,我们将满足条件的串分成三大类:

(1)匹配串都在前\(m\)位(前半段)

(2)匹配串都在后\(m\)位(后半段)

(3)匹配串跨\(m\)这个位置

​  第三类又可以再分两类:跨\(m\)这个位置的串在前半段的长度比较大、在后半段的长度比较大

​   

  首先考虑前两类怎么计算:其实只要把正串和翻转之后再\(01\)反转的串都丢到ac自动机里面然后跑dp就好了

  看到这个\(n\)这么小,大概差不多就是用来状压的了吧,于是粗暴地令\(f[i][j][k]\)表示确定了前\(i\)位,当前在\(j\)这个节点,当前已经包含的串状态为\(k\),然后直接\(O(m*\)自动机节点数\(*2^n)\)暴力dp就好了

  具体一点就是对于每个节点记录一个\(st[x]\)表示走到这个节点意味着包含了哪些字符串,预处理的时候从fail树上面从上往下传就好了(当然实现的时候并不用真的建出来,记录一下bfs序然后直接传就好了)

​  最后就是第三种情况,这个其实也比较好搞,对于每个自动机上的节点我们维护一个\(midst[x]\)表示这个节点作为新串中的第\(m\)位可以包含到哪些匹配串,我们枚举每个匹配串(包括反串)的每一位,如果这个位置可以作为满足条件的串的第\(m\)位的话(说白了就是可从这个位置切开满足反对称),并且在这里切开之后满足前半段的长度更长的话(因为枚举的字符串中既有正串也有反串,所以只要保证一种情况就可以将(3)中的两小类不重不漏地算进去了),我们将其加入对应的自动机节点的\(midst[x]\)里面去,然后同理这个\(midst\)也要下传,方式和上面的\(st\)一样

​  最后查答案的时候枚举\(f\)的后两维,如果说当前的状态\(k|midst[j]=\)满状态的话,就将\(f[m][j][k]\)加入答案中

  
  代码大概长这个样子

#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
using namespace std;
const int N=15,L=110,M=510,MOD=998244353;
char s[N][L];
int n,m,ans,all;
int St(int x){if (x>n) x-=n; return 1<<x-1;}
bool in(int st,int x){return st>>x-1&1;}
int mul(int x,int y){return 1LL*x*y%MOD;}
int plu(int x,int y){return 1LL*x+y-(1LL*x+y>=MOD?MOD:0);}
namespace Ac{/*{{{*/
    const int N=1210,C=2,ST=(1<<6)+10;
    queue<int> q;
    int ch[N][C],fail[N],st[N],lis[N],midst[N];
    int f[M][N][ST];
    int tot,rt;
    void init(){tot=0; rt=0;}
    void debug(){
        for (int i=rt;i<=tot;++i) printf("%d ",st[i]); printf("\n");
    }
    int newnode(){
        fail[++tot]=0; st[tot]=0;
        for (int i=0;i<C;++i) ch[tot][i]=0;
        return tot;
    }
    void insert(int id){
        int now=rt,c,len=strlen(s[id]);
        for (int i=0;i<len;++i){
            c=s[id][i]-'0';
            if (!ch[now][c]) ch[now][c]=newnode();
            now=ch[now][c];
        }
        st[now]|=St(id);
    }
    void build(){
        int u,v;
        while (!q.empty()) q.pop();
        q.push(rt); lis[0]=0;
        while (!q.empty()){
            v=q.front(); q.pop(); lis[++lis[0]]=v;
            for (int i=0;i<C;++i){
                if (!ch[v][i]){
                    ch[v][i]=ch[fail[v]][i];
                    continue;
                }
                if (v==rt)
                    fail[ch[v][i]]=rt;
                else
                    fail[ch[v][i]]=ch[fail[v]][i];
                q.push(ch[v][i]);
            }
        }
        for (int i=1;i<=lis[0];++i) st[lis[i]]|=st[fail[lis[i]]];
    }
    void dp(){
        int u;
        f[0][rt][0]=1;
        for (int i=0;i<m;++i){
            for (int j=rt;j<=tot;++j)
                for (int stt=0;stt<=all;++stt){
                    if (f[i][j][stt]==0) continue;
                    for (int k=0;k<C;++k){
                        u=ch[j][k];
                        f[i+1][u][stt|st[u]]=plu(f[i+1][u][stt|st[u]],f[i][j][stt]);
                    }
                }
        }
    }
    bool check(int which,int mid){
        int tot1=mid,tot2=mid+1,len=strlen(s[which]);
        while (tot1>=0&&tot2<len){
            if (s[which][tot1]==s[which][tot2]) return 0;
            --tot1; ++tot2;
        }
        return 1;
    }
    void calc_mid(){
        int len,now,c;
        for (int i=1;i<=n*2;++i){
            len=strlen(s[i]);
            now=rt;
            for (int j=0;j<len-1;++j){
                c=s[i][j]-'0';
                if (check(i,j)&&(j+1)*2>=len)
                    midst[ch[now][c]]|=St(i);
                now=ch[now][c];
            }
        }
        for (int i=1;i<=lis[0];++i)
            midst[lis[i]]|=midst[fail[lis[i]]];
    }
    void solve(){
        build();
        dp();
        calc_mid();
        ans=0;
        for (int i=rt;i<=tot;++i){
            for (int stt=0;stt<=all;++stt){
                if ((midst[i]|stt)==all)
                    ans=plu(ans,f[m][i][stt]);
            }
        }
        printf("%d\n",ans);
    }
}/*}}}*/
 
int main(){
#ifndef ONLINE_JUDGE
    freopen("a.in","r",stdin);
#endif
    scanf("%d%d",&n,&m);
    int len;
    Ac::init();
    all=1<<n; --all;
    for (int i=1;i<=n;++i){
        scanf("%s",s[i]);
        len=strlen(s[i]);
        for (int j=0;j<len;++j)
            s[n+i][len-1-j]='0'+((s[i][j]-'0')^1);
        Ac::insert(i);
        Ac::insert(n+i);
    }
    Ac::solve();
}
posted @ 2018-12-05 21:00  yoyoball  阅读(212)  评论(0编辑  收藏  举报