[CQOI2014]通配符匹配 题解

\(dp[i][j]\)表示文本串的前\(j\)个字符匹配了模式串第\(i\)个通配符(包括这个通配符)前面的所有字符,值为\(0\)代表不能,值为\(1\)代表可以

那么,显然有两种转移:第\(i+1\)个通配符是'\(*\)'或者'\(?\)'

转移的条件,是从第\(j+1\)个字符开始的一段字符串可以与第\(i\)个和第\(i+1\)个通配符之间的模板串字符匹配

设这一段模板串长度为\(k\)

如果是'\(?\)',那么\(dp[i+1][j+k+1]=1\)

如果是'\(*\)',那么\(dp[i+1][j+k...strlen(s)]=1\)

这个递推貌似是对的,但是有一个问题:

怎么足够快地知道,从第\(j+1\)个字符开始的一段字符串,与第\(i\)个和第\(i+1\)个通配符之间的模板串字符,可不可以匹配???

字符串Hash!!

#include <bits/stdc++.h>
#define ull signed long long
using namespace std ;
const int MAXN = 100000 + 5 ;
const ull Base = 19260817ull ;
char a[ MAXN ] , b[ 15 ][ MAXN ] ;
int len , dp[ 15 ][ MAXN ] ;
ull pw[ MAXN ] , pre[ MAXN ] , h[ 15 ] ;
int n , cnt , st[ 15 ] , sp[ 15 ] ;
inline ull Hash ( char s[] , int len ) {
    ull res = 0 ;
    for ( int i = 0 ; i < len ; i ++ ) res *= Base , res += (ull) s[ i ] ;
    return res ;
}
signed main () {
    // freopen ( "match3.in" , "r" , stdin ) ;
    scanf ( "%s" , a ) ;
    len = strlen ( a ) ; pw[ 0 ] = 1 ;
    for ( int i = 1 ; i <= MAXN - 5 ; i ++ ) pw[ i ] = pw[ i - 1 ] * Base ;
    if ( a[ 0 ] == '*' || a[ 0 ] == '?' ) h[ ++ cnt ] = Hash ( b[ cnt ] , st[ cnt ] = 0 ) ;
    for ( int i = 0 ; i < len ; ) {
        if ( a[ i ] == '*' || a[ i ] == '?' ) sp[ cnt ] = ( sp[ cnt ] || ( a[ i ] == '*' ) ) , i ++ ;
        int j = 0 ; cnt ++ ;
        while ( a[ i ] != '*' && a[ i ] != '?' && i < len ) b[ cnt ][ j ] = a[ i ] , i ++ , j ++ ;
        h[ cnt ] = Hash ( b[ cnt ] , st[ cnt ] = j ) ;
    }
    if ( a[ len - 1 ] == '*' || a[ len - 1 ] == '?' ) sp[ ++ cnt ] = h[ cnt ] = st[ cnt ] = 0 ;
    else sp[ cnt + 1 ] = 0 ;
    len ++ ;
    // for ( int i = 1 ; i <= cnt ; i ++ ) cout << st[ i ] << " " ; cout << endl ;
    scanf ( "%d" , &n ) ;
    while ( n -- ) {
        memset ( a , 0 , sizeof ( a ) ) ;
        scanf ( "%s" , a ) ; a[ strlen ( a ) ] = '$' ;
        len = strlen ( a ) ;
        memset ( dp , 0 , sizeof ( dp ) ) ; dp[ 0 ][ 0 ] = 1 ;
        for ( int i = 0 ; i < len ; i ++ ) pre[ i + 1 ] = pre[ i ] * Base + (ull) a[ i ] ;
        for ( int j = 0 ; j <= len ; j ++ ) {
            for ( int i = 0 ; i <= cnt ; i ++ ) {
                if ( !dp[ i ][ j ] ) continue ;
                ull t1 = pre[ j + st[ i + 1 ] ] - pre[ j ] * pw[ st[ i + 1 ] ] ;
                if ( t1 == h[ i + 1 ] ) {
                    if ( sp[ i + 1 ] ) for ( int k = j + st[ i + 1 ] ; k <= len ; k ++ ) dp[ i + 1 ][ k ] = 1 ;
                    else dp[ i + 1 ][ j + st[ i + 1 ] + 1 ] = 1 ;
                }
             }
        }
        if ( dp[ cnt ][ len ] ) printf ( "YES\n" ) ;
        else printf ( "NO\n" ) ;
    }
    return 0 ;
}

结果提交,90pts ,TLE

优化

我们发现程序中的

for ( int k = j + st[ i + 1 ] ; k <= len ; k ++ ) dp[ i + 1 ][ k ] = 1 ;

这个for循环时间复杂度实在太高

怎么办办?

这个\(for\)循环的作用,是在\(sp[i+1]=1\),也就是下一个通配符为'\(*\)'的时候,用来一路更新下去的

但是这样更新来更新去,一定会导致\(TLE\)

那我们需要一个优化,让这个循环的过程分散到遍历\(dp[i][j]\)的时候去,省去一层\(n\)的复杂度

这里,我们考虑使用不同的值来表示\(dp[i][j]\)的不同意义:

\(dp[i][j]=−1\)的时候,说明这个节点没有访问过,\(continue\)

\(dp[i][j]=0\)的时候,说明这个节点被且仅被一个'?'往后的递推访问过,这时我们令\(dp[i][j]=2,dp[i][j+1]=2\),并\(continue\)(因为当前节点并没有意义,只是访问过,不能继续递推)

\(dp[i][j]=1\)的时候,说明这个节点是被'\(*\)'访问过的,这时我们令\(dp[i][j+1]=1\),并且这个点有意义,可以往下递推

\(dp[i][j]=2\)的时候,说明这个节点被'\(?\)'访问过的节点更新到了\(2\),这时直接从这个节点往后递推,不需要更新值

最后,当\(dp[i][j]=3\)的时候——这个是一个非常特殊的情况

我们发现,上述的\(-1\)\(2\)的值里面,\(1\)的优先级最高,\(0\)次之,\(2\)最低,\(-1\)可以被它们随便覆盖

但是我们的确会出现这样的情况:一个\(0\)延伸出来的\(2\),覆盖到了另一个\(0\)

此时这个\(0\)不仅会令\(dp[i][j]=dp[i][j+1]=2\),它自身也需要往下递推,而不是直接\(continue\)(因为上一个过来的2说明它有这个意义)

所以我们令这种情况下的\(dp[i][j]\)的值为\(3\),此时令\(dp[i][j+1]=2\),并且从当前节点递推

初始化的时候,全部设为\(-1\)\(dp[0][0]=2\)

最后如果\(dp\)[模板串的段数][文本串长度]不是\(-1\)的话,就输出\(YES\),否则\(NO\)

#include <bits/stdc++.h>
#define ull signed long long
using namespace std ;
const int MAXN = 100000 + 5 ;
const ull Base = 19260817ull ;
char a[ MAXN ] , b[ 15 ][ MAXN ] ;
int len , dp[ 15 ][ MAXN ] ;
ull pw[ MAXN ] , pre[ MAXN ] , h[ 15 ] ;
int n , cnt , st[ 15 ] , sp[ 15 ] ;
inline ull Hash ( char s[] , int len ) {
    ull res = 0 ;
    for ( int i = 0 ; i < len ; i ++ ) res *= Base , res += (ull) s[ i ] ;
    return res ;
}
signed main () {
    // freopen ( "match3.in" , "r" , stdin ) ;
    scanf ( "%s" , a ) ;
    len = strlen ( a ) ; pw[ 0 ] = 1 ;
    for ( int i = 1 ; i <= MAXN - 5 ; i ++ ) pw[ i ] = pw[ i - 1 ] * Base ;
    if ( a[ 0 ] == '*' || a[ 0 ] == '?' ) h[ ++ cnt ] = Hash ( b[ cnt ] , st[ cnt ] = 0 ) ;
    for ( int i = 0 ; i < len ; ) {
        if ( a[ i ] == '*' || a[ i ] == '?' ) sp[ cnt ] = ( sp[ cnt ] || ( a[ i ] == '*' ) ) , i ++ ;
        int j = 0 ; cnt ++ ;
        while ( a[ i ] != '*' && a[ i ] != '?' && i < len ) b[ cnt ][ j ] = a[ i ] , i ++ , j ++ ;
        h[ cnt ] = Hash ( b[ cnt ] , st[ cnt ] = j ) ;
    }
    if ( a[ len - 1 ] == '*' || a[ len - 1 ] == '?' ) sp[ ++ cnt ] = h[ cnt ] = st[ cnt ] = 0 ;
    else sp[ cnt + 1 ] = 0 ;
    len ++ ;
    // for ( int i = 1 ; i <= cnt ; i ++ ) cout << st[ i ] << " " ; cout << endl ;
    scanf ( "%d" , &n ) ;
    while ( n -- ) {
        memset ( a , 0 , sizeof ( a ) ) ;
        scanf ( "%s" , a ) ; a[ strlen ( a ) ] = '$' ;
        len = strlen ( a ) ;
        memset ( dp , -1 , sizeof ( dp ) ) ; dp[ 0 ][ 0 ] = 2 ;
        for ( int i = 0 ; i < len ; i ++ ) pre[ i + 1 ] = pre[ i ] * Base + (ull) a[ i ] ;
        for ( int j = 0 ; j <= len ; j ++ ) {
            for ( int i = 0 ; i <= cnt ; i ++ ) {
                if ( dp[ i ][ j ] == -1 ) continue ;
                if ( dp[ i ][ j ] == 1 ) dp[ i ][ j + 1 ] = 1 ;
                if ( !dp[ i ][ j ] ) {
                    dp[ i ][ j ] = 2 ;
                    if ( dp[ i ][ j + 1 ] == -1 ) dp[ i ][ j + 1 ] = 2 ;
                    if ( dp[ i ][ j + 1 ] == 0 ) dp[ i ][ j + 1 ] = 3 ;
                    continue ;
                }
                if ( dp[ i ][ j ] == 3 ) {
                    dp[ i ][ j ] = 2 ;
                    if ( dp[ i ][ j + 1 ] == -1 ) dp[ i ][ j + 1 ] = 2 ;
                    if ( dp[ i ][ j + 1 ] == 0 ) dp[ i ][ j + 1 ] = 3 ;
                }
                ull t1 = pre[ j + st[ i + 1 ] ] - pre[ j ] * pw[ st[ i + 1 ] ] ;
                if ( t1 == h[ i + 1 ] ) dp[ i + 1 ][ j + st[ i + 1 ] ] = max ( dp[ i + 1 ][ j + st[ i + 1 ] ] , sp[ i + 1 ] ) ;
             }
        }
        if ( dp[ cnt ][ len ] != -1 ) printf ( "YES\n" ) ;
        else printf ( "NO\n" ) ;
    }
    return 0 ;
}

posted @ 2020-08-07 20:48  hulean  阅读(165)  评论(0编辑  收藏  举报