[CQOI2014]通配符匹配 题解
设\(dp[i][j]\)表示文本串的前\(j\)个字符匹配了模式串第\(i\)个通配符(包括这个通配符)前面的所有字符,值为\(0\)代表不能,值为\(1\)代表可以
那么,显然有两种转移:第\(i+1\)个通配符是'\(*\)'或者'\(?\)'
转移的条件,是从第\(j+1\)个字符开始的一段字符串可以与第\(i\)个和第\(i+1\)个通配符之间的模板串字符匹配
设这一段模板串长度为\(k\)
如果是'\(?\)',那么\(dp[i+1][j+k+1]=1\)
如果是'\(*\)',那么\(dp[i+1][j+k...strlen(s)]=1\)
这个递推貌似是对的,但是有一个问题:
怎么足够快地知道,从第\(j+1\)个字符开始的一段字符串,与第\(i\)个和第\(i+1\)个通配符之间的模板串字符,可不可以匹配???
字符串Hash!!
#include <bits/stdc++.h>
#define ull signed long long
using namespace std ;
const int MAXN = 100000 + 5 ;
const ull Base = 19260817ull ;
char a[ MAXN ] , b[ 15 ][ MAXN ] ;
int len , dp[ 15 ][ MAXN ] ;
ull pw[ MAXN ] , pre[ MAXN ] , h[ 15 ] ;
int n , cnt , st[ 15 ] , sp[ 15 ] ;
inline ull Hash ( char s[] , int len ) {
ull res = 0 ;
for ( int i = 0 ; i < len ; i ++ ) res *= Base , res += (ull) s[ i ] ;
return res ;
}
signed main () {
// freopen ( "match3.in" , "r" , stdin ) ;
scanf ( "%s" , a ) ;
len = strlen ( a ) ; pw[ 0 ] = 1 ;
for ( int i = 1 ; i <= MAXN - 5 ; i ++ ) pw[ i ] = pw[ i - 1 ] * Base ;
if ( a[ 0 ] == '*' || a[ 0 ] == '?' ) h[ ++ cnt ] = Hash ( b[ cnt ] , st[ cnt ] = 0 ) ;
for ( int i = 0 ; i < len ; ) {
if ( a[ i ] == '*' || a[ i ] == '?' ) sp[ cnt ] = ( sp[ cnt ] || ( a[ i ] == '*' ) ) , i ++ ;
int j = 0 ; cnt ++ ;
while ( a[ i ] != '*' && a[ i ] != '?' && i < len ) b[ cnt ][ j ] = a[ i ] , i ++ , j ++ ;
h[ cnt ] = Hash ( b[ cnt ] , st[ cnt ] = j ) ;
}
if ( a[ len - 1 ] == '*' || a[ len - 1 ] == '?' ) sp[ ++ cnt ] = h[ cnt ] = st[ cnt ] = 0 ;
else sp[ cnt + 1 ] = 0 ;
len ++ ;
// for ( int i = 1 ; i <= cnt ; i ++ ) cout << st[ i ] << " " ; cout << endl ;
scanf ( "%d" , &n ) ;
while ( n -- ) {
memset ( a , 0 , sizeof ( a ) ) ;
scanf ( "%s" , a ) ; a[ strlen ( a ) ] = '$' ;
len = strlen ( a ) ;
memset ( dp , 0 , sizeof ( dp ) ) ; dp[ 0 ][ 0 ] = 1 ;
for ( int i = 0 ; i < len ; i ++ ) pre[ i + 1 ] = pre[ i ] * Base + (ull) a[ i ] ;
for ( int j = 0 ; j <= len ; j ++ ) {
for ( int i = 0 ; i <= cnt ; i ++ ) {
if ( !dp[ i ][ j ] ) continue ;
ull t1 = pre[ j + st[ i + 1 ] ] - pre[ j ] * pw[ st[ i + 1 ] ] ;
if ( t1 == h[ i + 1 ] ) {
if ( sp[ i + 1 ] ) for ( int k = j + st[ i + 1 ] ; k <= len ; k ++ ) dp[ i + 1 ][ k ] = 1 ;
else dp[ i + 1 ][ j + st[ i + 1 ] + 1 ] = 1 ;
}
}
}
if ( dp[ cnt ][ len ] ) printf ( "YES\n" ) ;
else printf ( "NO\n" ) ;
}
return 0 ;
}
结果提交,90pts ,TLE
优化
我们发现程序中的
for ( int k = j + st[ i + 1 ] ; k <= len ; k ++ ) dp[ i + 1 ][ k ] = 1 ;
这个for
循环时间复杂度实在太高
怎么办办?
这个\(for\)循环的作用,是在\(sp[i+1]=1\),也就是下一个通配符为'\(*\)'的时候,用来一路更新下去的
但是这样更新来更新去,一定会导致\(TLE\)
那我们需要一个优化,让这个循环的过程分散到遍历\(dp[i][j]\)的时候去,省去一层\(n\)的复杂度
这里,我们考虑使用不同的值来表示\(dp[i][j]\)的不同意义:
当\(dp[i][j]=−1\)的时候,说明这个节点没有访问过,\(continue\)
当\(dp[i][j]=0\)的时候,说明这个节点被且仅被一个'?'往后的递推访问过,这时我们令\(dp[i][j]=2,dp[i][j+1]=2\),并\(continue\)(因为当前节点并没有意义,只是访问过,不能继续递推)
当\(dp[i][j]=1\)的时候,说明这个节点是被'\(*\)'访问过的,这时我们令\(dp[i][j+1]=1\),并且这个点有意义,可以往下递推
当\(dp[i][j]=2\)的时候,说明这个节点被'\(?\)'访问过的节点更新到了\(2\),这时直接从这个节点往后递推,不需要更新值
最后,当\(dp[i][j]=3\)的时候——这个是一个非常特殊的情况
我们发现,上述的\(-1\)到\(2\)的值里面,\(1\)的优先级最高,\(0\)次之,\(2\)最低,\(-1\)可以被它们随便覆盖
但是我们的确会出现这样的情况:一个\(0\)延伸出来的\(2\),覆盖到了另一个\(0\)
此时这个\(0\)不仅会令\(dp[i][j]=dp[i][j+1]=2\),它自身也需要往下递推,而不是直接\(continue\)(因为上一个过来的2说明它有这个意义)
所以我们令这种情况下的\(dp[i][j]\)的值为\(3\),此时令\(dp[i][j+1]=2\),并且从当前节点递推
初始化的时候,全部设为\(-1\),\(dp[0][0]=2\)
最后如果\(dp\)[模板串的段数][文本串长度]不是\(-1\)的话,就输出\(YES\),否则\(NO\)
#include <bits/stdc++.h>
#define ull signed long long
using namespace std ;
const int MAXN = 100000 + 5 ;
const ull Base = 19260817ull ;
char a[ MAXN ] , b[ 15 ][ MAXN ] ;
int len , dp[ 15 ][ MAXN ] ;
ull pw[ MAXN ] , pre[ MAXN ] , h[ 15 ] ;
int n , cnt , st[ 15 ] , sp[ 15 ] ;
inline ull Hash ( char s[] , int len ) {
ull res = 0 ;
for ( int i = 0 ; i < len ; i ++ ) res *= Base , res += (ull) s[ i ] ;
return res ;
}
signed main () {
// freopen ( "match3.in" , "r" , stdin ) ;
scanf ( "%s" , a ) ;
len = strlen ( a ) ; pw[ 0 ] = 1 ;
for ( int i = 1 ; i <= MAXN - 5 ; i ++ ) pw[ i ] = pw[ i - 1 ] * Base ;
if ( a[ 0 ] == '*' || a[ 0 ] == '?' ) h[ ++ cnt ] = Hash ( b[ cnt ] , st[ cnt ] = 0 ) ;
for ( int i = 0 ; i < len ; ) {
if ( a[ i ] == '*' || a[ i ] == '?' ) sp[ cnt ] = ( sp[ cnt ] || ( a[ i ] == '*' ) ) , i ++ ;
int j = 0 ; cnt ++ ;
while ( a[ i ] != '*' && a[ i ] != '?' && i < len ) b[ cnt ][ j ] = a[ i ] , i ++ , j ++ ;
h[ cnt ] = Hash ( b[ cnt ] , st[ cnt ] = j ) ;
}
if ( a[ len - 1 ] == '*' || a[ len - 1 ] == '?' ) sp[ ++ cnt ] = h[ cnt ] = st[ cnt ] = 0 ;
else sp[ cnt + 1 ] = 0 ;
len ++ ;
// for ( int i = 1 ; i <= cnt ; i ++ ) cout << st[ i ] << " " ; cout << endl ;
scanf ( "%d" , &n ) ;
while ( n -- ) {
memset ( a , 0 , sizeof ( a ) ) ;
scanf ( "%s" , a ) ; a[ strlen ( a ) ] = '$' ;
len = strlen ( a ) ;
memset ( dp , -1 , sizeof ( dp ) ) ; dp[ 0 ][ 0 ] = 2 ;
for ( int i = 0 ; i < len ; i ++ ) pre[ i + 1 ] = pre[ i ] * Base + (ull) a[ i ] ;
for ( int j = 0 ; j <= len ; j ++ ) {
for ( int i = 0 ; i <= cnt ; i ++ ) {
if ( dp[ i ][ j ] == -1 ) continue ;
if ( dp[ i ][ j ] == 1 ) dp[ i ][ j + 1 ] = 1 ;
if ( !dp[ i ][ j ] ) {
dp[ i ][ j ] = 2 ;
if ( dp[ i ][ j + 1 ] == -1 ) dp[ i ][ j + 1 ] = 2 ;
if ( dp[ i ][ j + 1 ] == 0 ) dp[ i ][ j + 1 ] = 3 ;
continue ;
}
if ( dp[ i ][ j ] == 3 ) {
dp[ i ][ j ] = 2 ;
if ( dp[ i ][ j + 1 ] == -1 ) dp[ i ][ j + 1 ] = 2 ;
if ( dp[ i ][ j + 1 ] == 0 ) dp[ i ][ j + 1 ] = 3 ;
}
ull t1 = pre[ j + st[ i + 1 ] ] - pre[ j ] * pw[ st[ i + 1 ] ] ;
if ( t1 == h[ i + 1 ] ) dp[ i + 1 ][ j + st[ i + 1 ] ] = max ( dp[ i + 1 ][ j + st[ i + 1 ] ] , sp[ i + 1 ] ) ;
}
}
if ( dp[ cnt ][ len ] != -1 ) printf ( "YES\n" ) ;
else printf ( "NO\n" ) ;
}
return 0 ;
}