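BZOJ 4000 [TJOI2015] 棋盘：矩阵快速幂优化状压 DP
思路：用 m 位二进制数表示一行的棋子摆放，先筛出行内互不攻击的状态，再对相邻两行的合法状态对建立 0/1 转移矩阵；该矩阵的 n−1 次幂所有元素之和即为答案（用 unsigned 自然溢出实现对 2^32 取模）。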
/**************************************************************
    Problem: 4000
    User: idy002
    Language: C++
    Result: Accepted
    Time:32 ms
    Memory:836 kb
****************************************************************/

// Counts placements of pieces on an n x m board such that no two pieces
// attack each other. Arithmetic is done in `unsigned`, so the result is taken
// modulo 2^32 by wrap-around (assumes a 32-bit unsigned, as on all mainstream
// targets).
//
// Input: n m p k, followed by a 3 x p attack template (nonzero = attacked).
// Template row 1 is the piece's own row and its column k (0-based) is the
// piece itself; rows 0 and 2 describe the row above and the row below.
//
// Method: encode one board row as an m-bit mask, keep the masks whose pieces
// do not attack each other within the row, build the 0/1 compatibility matrix
// between a row and the row directly below it, and raise it to the (n-1)-th
// power. The sum of all entries of that power is the number of valid boards.

#include <cstdio>

#define M 6  // largest supported board width, i.e. bits per row mask

struct Matrix {
    unsigned v[1 << M][1 << M];
    const unsigned *operator[](int i) const { return v[i]; }
};

int n, m, p, k, bound;      // bound = (1 << m) - 1, mask of all board columns
unsigned attack[3];         // template rows as bitmasks: bit j <=> column j
int stat[1 << M], stot;     // the stot row masks with no intra-row attacks
Matrix base, dest;

// Sets the top-left stot x stot part of x to the identity matrix.
void make_unit(Matrix &x) {
    for (int i = 0; i < stot; i++)
        for (int j = 0; j < stot; j++)
            x.v[i][j] = (i == j);
}

// Product of the top-left stot x stot parts of a and b (mod 2^32).
Matrix operator*(const Matrix &a, const Matrix &b) {
    Matrix c;
    for (int i = 0; i < stot; i++)
        for (int j = 0; j < stot; j++) {
            unsigned sum = 0;
            for (int t = 0; t < stot; t++)
                sum += a[i][t] * b[t][j];
            c.v[i][j] = sum;
        }
    return c;
}

// Returns a^e by binary exponentiation; e <= 0 yields the identity.
Matrix mpow(Matrix a, int e) {
    Matrix rt;
    make_unit(rt);
    while (e > 0) {
        if (e & 1)
            rt = rt * a;
        e >>= 1;
        if (e)
            a = a * a;  // skip the squaring that would never be used
    }
    return rt;
}

// Returns the union (as an m-bit mask) of the cells attacked by every piece
// in row mask s, using template row a. For a piece in column b, the template
// is shifted so that its column k lands on b, then clipped to the board.
int getarea(int s, unsigned a) {
    unsigned rt = 0;
    for (int b = 0; b < m; b++) {
        if (!((s >> b) & 1))
            continue;
        unsigned aa = (b < k) ? (a >> (k - b)) : (a << (b - k));
        rt |= aa & (unsigned)bound;
    }
    return (int)rt;
}

// Collects the row masks whose pieces do not attack each other into
// stat[0..stot) (in increasing mask order) and sets base[i][j] to 1 exactly
// when row stat[j] may sit directly below row stat[i].
void build() {
    stot = 0;
    for (int s = 0; s <= bound; s++) {
        if (getarea(s, attack[1]) & s)
            continue;  // two pieces in the same row attack each other
        stat[stot++] = s;
    }
    for (int i = 0; i < stot; i++)
        for (int j = 0; j < stot; j++) {
            int upper = stat[i], lower = stat[j];
            bool ok = !(getarea(upper, attack[2]) & lower)   // upper hits lower
                   && !(getarea(lower, attack[0]) & upper);  // lower hits upper
            base.v[i][j] = ok ? 1u : 0u;
        }
}

int main() {
    if (scanf("%d%d%d%d", &n, &m, &p, &k) != 4)
        return 1;
    // Reject inputs that would overflow the fixed-size tables, make the bit
    // shifts undefined, or give mpow a meaningless exponent.
    if (n < 1 || m < 1 || m > M || p < 1 || p > 32 || k < 0 || k >= p)
        return 1;
    bound = (1 << m) - 1;
    for (int i = 0; i < 3; i++)
        for (int j = 0; j < p; j++) {
            int o;
            if (scanf("%d", &o) != 1)
                return 1;
            if (o)
                attack[i] |= 1u << j;
        }
    // A piece does not attack its own cell. Clear the bit unconditionally:
    // toggling it would instead set it if the template marked it 0.
    attack[1] &= ~(1u << k);
    build();
    dest = mpow(base, n - 1);
    unsigned ans = 0;
    for (int i = 0; i < stot; i++)
        for (int j = 0; j < stot; j++)
            ans += dest[i][j];
    printf("%u\n", ans);
    return 0;
}