hdu 2243(poj2778的加强版!(AC自动机+矩阵))

问你长度为1~N的串中包含了模式串的串总共有几个(先求出总共小于L的单词数(26^1+26^2+26^3+...26^L)..然后再减去不包括所给字符串的单词

答案要模2^64,直接用unsinged __int64!!!!

算法:AC自动机+二分求等比矩阵和+二分求等比数列和

(ps:

等比矩阵求和(或等比数列),有经典算法,假定原矩阵为A,阶数为n,那么构造一个阶数为2n的矩阵,如下
      | A   E |         其中O代表O矩阵,E代表单位矩阵,这样,求出的K次矩阵的右上n子矩阵正好是
      | O   E |         等比矩阵的K项和,这种构造法比我实现的两次二分快了4倍左右。

这里有个错误要注意...求出K次矩阵的右上子矩阵应该是K-1项的和...

 

//#pragma comment(linker, "/STACK:102400000")
#include<cstdlib>
#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<set>
#include<map>
#include<list>
#include<queue>
#include<stack>
#include<vector>
#define tree int o,int l,int r
#define lson o<<1,l,mid
#define rson o<<1|1,mid+1,r
#define lo o<<1
#define ro o<<1|1
#define pb push_back
#define mp make_pair
#define ULL unsigned long long
#define inf 0x7fffffff
#define eps 1e-7
#define N 35
#define M 26
using namespace std;
//void print2(ULL ma[][N],int n,int m)
//{
//    cout<<"=============="<<endl;
//    for (int i=0; i<n ; i++ )
//    {
//        printf("%d",ma[i][0]);
//        for (int j=1; j<m ; j++)
//            printf(" %d",ma[i][j]);
//        cout<<endl;
//    }
//    cout<<"-----------"<<endl;
//}

int m,n,T,t,x,y,u;
int ch[N][M];
int v[N];
int f[N],last[N],num;
ULL ma[N][N],temp[N][N],sum[N][N],sum2[N][N],ans;
ULL he,maxv;
void clear()//Trie树初始化
{
    num=1;
    memset(ma,0,sizeof(ma));
    memset(sum,0,sizeof(sum));
    memset(ch[0],0,sizeof(ch[0]));
    memset(v,0,sizeof(v));
    memset(last,0,sizeof(last));
}
int idx(char c)
{
    return c-'a';
}
void insert(char str[],int value)//建Trie树
{
    int len=strlen(str);
    int u=0;
    for (int i=0; i<len; ++i )
    {
        int c=idx(str[i]);
        if(!ch[u][c])//保存的是结点坐标
        {
            memset(ch[num],0,sizeof(ch[num]));
            ch[u][c]=num++;//
        }
        u=ch[u][c];
    }
    v[u]=value;
}
void getac()
{
    queue<int> q;//保存的节点下标
    f[0]=0;
    for (int c=0; c<M; ++c )
    {
        int u=ch[0][c];
        if(u)//不需要优化的else
        {
            q.push(u);
            f[u]=0;
            last[u]=v[u];//WA
        }
    }
    while(!q.empty())
    {
        int r=q.front();
        q.pop();
        for (int c=0; c<M; ++c )//注意:c表示节点值,不是结点位置
        {
            int u=ch[r][c];
            if(u)
            {
                q.push(u);
                int s=f[r];
//                while(s&&ch[s][c]==0)s=f[s];//可以简化
                f[u]=ch[s][c];
                last[u]=(v[u]||last[f[u]]);////
//                last[u]=v[f[u]]?f[u]:last[f[u]];
            }
            else //重要优化
                ch[r][c]=ch[f[r]][c];
        }
    }
}
char str[20];
void build()
{
    for(int i=0; i<num; i++)
    {
        for(int j=0; j<M; j++)
        {
            int u=ch[i][j];
            if(last[u]==0)
                ma[i][u]++;
        }
    }
}
void multi(ULL a[][N],ULL b[][N],int n)
{
    ULL c[N][N]= {0};
    for(int i=0; i<n; i++)
        for(int j=0; j<n; j++)
            for(int k=0; k<n; k++)
            {
                c[i][j]+=a[i][k]*b[k][j];
            }
    memcpy(a,c,sizeof(c));
}
void mat(int n,int num)
{
    if(n==1)
    {
        for(int i=0; i<num; i++)
            for(int j=0; j<num; j++)
                sum[i][j]=temp[i][j]=ma[i][j];
        he=maxv=26;
    }
    else
    {
        mat(n/2,num);
        memcpy(sum2,sum,sizeof(sum));
        multi(sum,temp,num);
        for(int i=0; i<num; i++)
            for(int j=0; j<num; j++)
                sum[i][j]+=sum2[i][j];

        multi(temp,temp,num);
        ULL hhh=he;
        he*=maxv;
        he+=hhh;
        maxv*=maxv;

        if(n&1)
        {
            maxv*=26;
            he+=maxv;
            multi(temp,ma,num);
            for(int i=0; i<num; i++)
                for(int j=0; j<num; j++)
                    sum[i][j]+=temp[i][j];
        }
    }
}
int main()
{
#ifndef ONLINE_JUDGE
    freopen("ex.in","r",stdin);
#endif
    int ncase=0;
    while(scanf("%d%d%*c",&m,&n)==2)
    {
        clear();
        while(m--)
        {
            scanf("%s",str);
            insert(str,1);
        }
        getac();
        build();
        ans=0;
        mat(n,num);
        for(int i=0; i<num; i++)
        {
            ans+=sum[0][i];
        }
        printf("%I64u\n",he-ans);//%I64u
    }
    return 0;
}
View Code

 

posted @ 2013-10-17 17:36  baoff  阅读(219)  评论(0编辑  收藏  举报