【BZOJ4565】【HAOI2016】字符合并 [状压DP][区间DP]

字符合并

Time Limit: 20 Sec  Memory Limit: 256 MB
[Submit][Status][Discuss]

Description

  有一个长度为 n 的 01 串,你可以每次将相邻的 k 个字符合并,得到一个新的字符并获得一定分数。
  得到的新字符和分数由这 k 个字符确定。你需要求出你能获得的最大分数。

Input

  第一行两个整数n,k。接下来一行长度为n的01串,表示初始串。
  接下来2^k行,每行一个字符ci和一个整数wi,
  ci表示长度为k的01串连成二进制后按从小到大顺序得到的第i种合并方案得到的新字符,
  wi表示对应的第i种方案对应获得的分数。

Output

  输出一个整数表示答案

Sample Input

  3 2
  101
  1 10
  1 10
  0 20
  1 30

Sample Output

  40

HINT

  1<=n<=300 ,0<=ci<=1, wi>=1, k<=8

Solution

  我们显然考虑区间DP,再状态压缩一下,f[l][r][opt]表示[l, r]合成了opt最大价值

  如果一个区间长度为len的话,最后合完会长度会变为len % (k - 1)

  转移的本质是把长度为k的区间变成0/1,分情况处理。

    先枚举每一个断点pos,表示我们要把[pos, r]合成一个0/1,那么就要保证(r - pos + 1) % (k - 1) = 1,否则我们DP的时候,会把000看做是0一样转移,导致不能合成为一个0/1的合成了。

    若len % (k -1) = 1,则合成完会剩下一个数,我们判断一下[l, r]能否合成一个opt的状态,若可以,则f[l][r][c[opt]] = max(f[l][r][opt] + val[opt])。注意要先拿一个变量记录下来,不能直接更新,否则会出现0状态更新了1,然后1又用0更新了的情况,导致答案过大。

  最后答案显然就是max(f[1][n][opt])

Code

 1 #include<iostream>
 2 #include<string>
 3 #include<algorithm>
 4 #include<cstdio>
 5 #include<cstring>
 6 #include<cstdlib>
 7 #include<cmath>
 8 using namespace std;
 9 typedef long long s64;
10  
11 const int ONE = 305;
12 const int MOD = 1e9 + 7;
13  
14 int n, k;
15 int total;
16 int a[ONE];
17 char s[ONE];
18 int c[ONE], val[ONE];
19 s64 f[ONE][ONE][ONE];
20 s64 Ans;
21  
22 int get()
23 {
24         int res=1,Q=1;  char c;
25         while( (c=getchar())<48 || c>57)
26         if(c=='-')Q=-1;
27         if(Q) res=c-48; 
28         while((c=getchar())>=48 && c<=57) 
29         res=res*10+c-48;
30         return res*Q; 
31 }
32  
33 int main()
34 {
35         n = get();  k = get(); total = (1 << k) - 1;
36  
37         for(int i = 1; i <= n; i++)
38             for(int j = 1; j <= n; j++)
39                 for(int opt = 0; opt <= total; opt++)
40                     f[i][j][opt] = -1;
41  
42         scanf("%s", s + 1);
43         for(int i = 1; i <= n; i++)
44             a[i] = s[i] - '0', f[i][i][a[i]] = 0;
45  
46         for(int i = 0; i <= total; i++)
47             c[i] = get(), val[i] = get();
48  
49         for(int l = n; l >= 1; l--)
50             for(int r = l; r <= n; r++)
51                 {
52                     if(l == r) continue;
53  
54                     for(int pos = r - 1; pos >= l; pos -= k - 1)
55                         for(int opt = 0; opt <= total; opt++)
56                         {
57                             if(f[l][pos][opt] == -1) continue;
58                             if(f[pos + 1][r][0] != -1 && (opt << 1) <= total)
59                                 f[l][r][opt << 1] = max(f[l][r][opt << 1], f[l][pos][opt] + f[pos + 1][r][0]);
60                             if(f[pos + 1][r][1] != -1 && (opt << 1 | 1) <= total)
61                                 f[l][r][opt << 1 | 1] = max(f[l][r][opt << 1 | 1], f[l][pos][opt] + f[pos + 1][r][1]);
62                         }
63  
64                     if((r - l + 1) % (k - 1) == 1 || k == 2)
65                     {
66                         s64 A = -1, B = -1;
67                         for(int opt = 0; opt <= total; opt++)
68                             if(f[l][r][opt] != -1)
69                             {
70                                 if(c[opt] == 0) A = max(A, f[l][r][opt] + val[opt]);
71                                 if(c[opt] == 1) B = max(B, f[l][r][opt] + val[opt]);
72                             }
73  
74                         f[l][r][0] = max(f[l][r][0], A);
75                         f[l][r][1] = max(f[l][r][1], B);
76                     }
77                 }
78  
79         for(int opt = 0; opt <= total; opt++)
80             Ans = max(Ans, f[1][n][opt]);
81  
82         printf("%lld", Ans);
83  
84 }
85 
View Code

 

  • 制作
posted @ 2017-10-31 17:25  BearChild  阅读(297)  评论(0编辑  收藏  举报