bzoj 3120 矩阵优化DP

 

我的第一道需要程序建矩阵的矩阵优化DP。

 

题目可以将不同的p分开处理。

对于p==0 || p==1 直接是0或1

对于p>1,就要DP了。这里以p==3为例:

设dp[i][s1][s2][r]为前i列,结尾为0的有s1行(0表示女生,1表示男生),结尾为01的有s2个,结尾为011的有n-s1-s2个,有r列全是1的方案数。

状态这么复杂,看起来一点也不能用矩阵优化,但我们可以将状态(s1,s2,r)hash成整数,然后建立状态之间的转移。

 

收获:

这种m超过10^7的一般都要用矩阵优化,如果状态很复杂,可以将复杂的状态(但一般不多)hash成整数,然后用计算机建立状态之间的转移,然后就可以用矩阵优化了。

 

 

  1 /**************************************************************
  2     Problem: 3120
  3     User: idy002
  4     Language: C++
  5     Result: Accepted
  6     Time:24504 ms
  7     Memory:2616 kb
  8 ****************************************************************/
  9  
 10 #include <cstdio>
 11 #include <cstring>
 12 #define M 1000000007
 13 #define maxs 200
 14  
 15 typedef long long dint;
 16  
 17 int n, p, q;
 18 dint m;
 19 int comb[10][10];
 20  
 21 void init_comb() {
 22     for( int i=0; i<=9; i++ )
 23         for( int j=0; j<=i; j++ ) {
 24             if( j==0 || i==j )
 25                 comb[i][j]=1;
 26             else
 27                 comb[i][j] = (comb[i-1][j]+comb[i-1][j-1]);
 28         }
 29 }
 30 struct Matrix {
 31     int n;
 32     dint v[maxs][maxs];
 33     void init( int nn ) {
 34         n=nn;
 35         for( int i=0; i<n; i++ )
 36             for( int j=0; j<n; j++ )
 37                 v[i][j]=0;
 38     }
 39     void make_unit( int nn ) {
 40         n=nn;
 41         for( int i=0; i<n; i++ )
 42             for( int j=0; j<n; j++ )
 43                 v[i][j] = i==j;
 44     }
 45     Matrix operator*( const Matrix & b ) const {
 46         const Matrix & a = *this;
 47         Matrix rt;
 48         memset( &rt, 0, sizeof(rt) );
 49         rt.n = b.n;
 50         for( int k=0; k<n; k++ ) {
 51             for( int i=0; i<n; i++ ) {
 52                 if( a.v[i][k] ) {
 53                     for( int j=0; j<n; j++ ) {
 54                         if( b.v[k][j] ) {
 55                             rt.v[i][j] += (a.v[i][k]*b.v[k][j])%M;
 56                             if( rt.v[i][j]>=M ) rt.v[i][j]-=M;
 57                         }
 58                     }
 59                 }
 60             }
 61         }
 62         /*
 63         for( int i=0; i<n; i++ ) {
 64             for( int j=0; j<n; j++ ) {
 65                 rt.v[i][j] = 0;
 66                 for( int k=0; k<n; k++ ) {
 67                     rt.v[i][j] += (a.v[i][k]*b.v[k][j])%M;
 68                     if( rt.v[i][j]>=M )
 69                         rt.v[i][j]-=M;
 70                 }
 71             }
 72         }
 73         */
 74         return rt;
 75     }
 76 };
 77  
 78 Matrix mpow( Matrix a, dint b ) {
 79     Matrix rt;
 80     for( rt.make_unit(a.n); b; b>>=1,a=a*a )
 81         if( b&1 ) rt=(rt*a);
 82     return rt;
 83 }
 84 namespace Sec1 {
 85     void sov(){
 86         if( p==0 ) printf( "0\n" );
 87         else if( p==1 ) printf( "1\n" );
 88     }
 89 }
 90 namespace Sec2 {
 91     int info[maxs][3], idx[9][4], tot;
 92     Matrix trans;
 93     void init() {
 94         for( int r=0; r<=q; r++ )
 95             for( int a=0; a<=n; a++ ) {
 96                 info[tot][0] = a;
 97                 info[tot][1] = n-a;
 98                 info[tot][2] = r;
 99                 idx[a][r] = tot;
100                 tot++;
101             }
102         trans.init(tot);
103         for( int s=0; s<tot; s++ ) {
104             int a=info[s][0], r=info[s][2];
105             for( int sa=0; sa<=a; sa++ ) {
106                 int na=n-sa, nr=r+(sa==n);
107                 if( nr>q ) continue;
108                 int ns=idx[na][nr];
109                 trans.v[s][ns] = comb[a][sa];
110             }
111         }
112     }
113     void sov() {
114         init();
115         trans = mpow( trans, m );
116         int row = idx[n][0];
117         dint ans=0;
118         for( int i=0; i<tot; i++ ) {
119             ans += trans.v[row][i];
120             if( ans>=M ) ans-=M;
121         }
122         printf( "%lld\n", ans );
123     }
124 }
125 namespace Sec3 {
126     int info[maxs][4], idx[9][9][4], tot;
127     Matrix trans;
128     void init() {
129         for( int r=0; r<=q; r++ )
130             for( int a=0; a<=n; a++ )
131                 for( int b=0; b<=n-a; b++ ) {
132                     info[tot][0] = a;
133                     info[tot][1] = b;
134                     info[tot][2] = n-a-b;
135                     info[tot][3] = r;
136                     idx[a][b][r] = tot;
137                     tot++;
138                 }
139         trans.n = tot;
140         for( int s=0; s<tot; s++ ) {
141             int a=info[s][0], b=info[s][1], c=info[s][2], r=info[s][3];
142             for( int sa=0; sa<=a; sa++ )
143                 for( int sb=0; sb<=b; sb++ ) {
144                     int na=c+a-sa+b-sb, nb=sa, nr=r+(sa+sb==n);
145                     if( nr>q ) continue;
146                     int ns = idx[na][nb][nr];
147                     trans.v[s][ns] = comb[a][sa]*comb[b][sb];
148                 }
149         }
150     }
151     void sov(){
152         init();
153         trans = mpow( trans, m );
154         int row = idx[n][0][0];
155         dint ans=0;
156         for( int i=0; i<tot; i++ ) {
157             ans += trans.v[row][i];
158             if( ans>=M ) ans-=M;
159         }
160         printf( "%lld\n", ans );
161     }
162 }
163  
164 int main() {
165     scanf( "%d%lld%d%d", &n, &m, &p, &q );
166     init_comb();
167     if( p<=1 ) Sec1::sov();
168     else if( p==2 ) Sec2::sov();
169     else if( p==3 ) Sec3::sov();
170 }
View Code

 

posted @ 2015-03-05 16:17  idy002  阅读(297)  评论(0编辑  收藏  举报