数位DP

学了一下怎么写递归,发现确实比较简单;

dp[pos][][]对应dfs()中的参数的状态,记忆化当前状态的值,不用考虑这个状态表示什么意思;

然后就是设计好dfs()中的参数;

hdu 3555 http://acm.hdu.edu.cn/showproblem.php?pid=3555

题意:统计1~n之间含有49的数字的个数;

需要记录当前位置,前一位置放了那个数字,当前是否已经包含49,是否有上界;

dfs(pos,pre,istrue,limit);

 1 #include<cstdio>
 2 #include<cstring>
 3 #include<iostream>
 4 #include<algorithm>
 5 #include<iostream>
 6 using  namespace std;
 7 const int N = 20;
 8 typedef long long LL;
 9 int dig[N];
10 LL dp[N][10][2];
11 
12 LL dfs(int pos,int pre,int istrue,int limit) {
13     if (pos < 0) return istrue;
14     if (!limit && dp[pos][pre][istrue] != -1)
15         return dp[pos][pre][istrue];
16     int last = limit ? dig[pos] : 9;
17     LL ret = 0;
18     for (int i = 0; i <= last; i++) {
19         ret += dfs(pos-1,i,istrue || (pre == 4 && i == 9),limit && (i == last));
20     }
21     if (!limit) {
22         dp[pos][pre][istrue] = ret;
23     }
24     return ret;
25 }
26 LL solve(LL n) {
27     int len = 0;
28     while (n) {
29         dig[len++] = n % 10;
30         n /= 10;
31     }
32     return dfs(len-1,0,0,1);
33 }
34 int main(){
35     memset(dp,-1,sizeof(dp));
36     int T; scanf("%d",&T);
37     while (T--) {
38         LL n;
39         cin>>n;
40         cout<<solve(n)<<endl;
41     }
42     return 0;
43 }
View Code

 

 usetc 1307 http://acm.uestc.edu.cn/problem.php?pid=1307

题意:相邻两个数之差大于2;

dfs(pos,pre,limit,fg); fg表示前面是否全为0

 1 #include<cstdio>
 2 #include<iostream>
 3 #include<algorithm>
 4 #include<cmath>
 5 #include<cstdlib>
 6 #include<cstring>
 7 #include<iostream>
 8 using namespace std;
 9 typedef long long LL;
10 const int N = 20;
11 LL dp[N][10][2];
12 LL a,b;
13 int dig[N];
14 LL dfs(int pos,int pre,int limit,int fg) {
15     if (pos < 0) return 1;
16     if (!limit && dp[pos][pre][fg] != -1)
17         return dp[pos][pre][fg];
18     int last = limit ? dig[pos] : 9;
19     LL ret = 0;
20     for (int i = 0; i <= last; i++) {
21         if (fg == 0 || abs(i - pre) >= 2)
22         ret += dfs(pos-1,i,limit && (i == last),fg || i);
23     }
24     if (!limit) {
25         dp[pos][pre][fg] = ret;
26     }
27     return ret;
28 }
29 LL solve(LL n) {
30     int len = 0;
31     if (n < 0) return 0;
32     while (n) {
33         dig[len++] = n % 10;
34         n /= 10;
35     }
36     return dfs(len-1,0,1,0);
37 }
38 int main(){
39     memset(dp,-1,sizeof(dp));
40   //  cout<<solve(15)<<endl;
41     while (cin>>a>>b) {
42         cout<<solve(b)-solve(a-1)<<endl;
43     }
44     return 0;
45 }
View Code

 

hdu4352 http://acm.hdu.edu.cn/showproblem.php?pid=4352

题意:求[L,R]内最长递增子序列是k的个数;

分析:知道题意后,马上map<vector<>,LL> mp[][]搞了,然后华丽丽的T掉了,

vector<int> g,  g[i]表示最长递增子序列为长度为i结尾的值的最小值;

这是我对vector<>里面的值的性质没有思考,显然vector<>最多包含10个数,并且是严格递增的,

这样我们就可以直接状压存了,(1<<10)一个数值跟vector<>是一一对应的;

 1 #include<cstdio>
 2 #include<cstring>
 3 #include<iostream>
 4 #include<cmath>
 5 #include<algorithm>
 6 #include<queue>
 7 #include<vector>
 8 using namespace std;
 9 typedef long long LL;
10 LL dp[22][1<<10][12];
11 LL a,b;
12 int k;
13 int dig[22];
14 int cge(int sta,int k) {
15     if (sta & (1<<k)) return sta;
16     if ((1<<k) > sta) {
17         return sta | (1<<k);
18     }
19     sta |= 1<<k;
20     for (int i = k+1; i < 10; i++) {
21         if (sta & (1<<i)) {
22             return sta ^ (1<<i);
23         }
24     }
25 }
26 int get(int k) {
27     int ret = 0;
28     for (int i = 0; i < 10; i++) if (k & (1<<i)) ret++;
29     return ret;
30 }
31 LL dfs(int pos,int sta,int limit) {
32     if (pos < 0) return get(sta) == k;
33     if (!limit && dp[pos][sta][k] != -1)
34         return dp[pos][sta][k];
35     int last = limit ? dig[pos] : 9;
36     LL ret = 0;
37     for (int i = 0; i <= last; i++) {
38         ret += dfs(pos-1,sta || i ? cge(sta,i) : 0,limit && (i == last));
39     }
40     if (!limit) {
41         dp[pos][sta][k] = ret;
42     }
43     return ret;
44 }
45 LL solve(LL n) {
46     int len = 0;
47     while (n) {
48         dig[len++] = n % 10;
49         n /= 10;
50     }
51     return dfs(len-1,0,1);
52 }
53 int main(){
54     memset(dp,-1,sizeof(dp));
55     int T,cas = 0; scanf("%d",&T);
56     while (T--) {
57         scanf("%I64d%I64d%d",&a,&b,&k);
58         printf("Case #%d: ",++cas);
59         printf("%I64d\n",solve(b) - solve(a-1));
60     }
61     return 0;
62 }
View Code

 

hdu3886 http://acm.hdu.edu.cn/showproblem.php?pid=3886

题意:求[l,r]内满足题意条件的数的个数,给你一个字符串这里且称为标准串,要数值满足这个标准串(条件比较难以表述,看题);

分析:dfs(pos,pre,loc,cc,limit,fg);

分别表示当前位置,前一位的数值,当前匹配到标准串中位置,跟标准串中匹配了几个数值,是否有限制,标记有无前导零;

有个trick,找了好久,如果直接来的话,比如// 1234 1234会输出2,因为12 34 @ 123 4 被当成不同的数了,我们可以规定如果标准串中有连续相同的字符,并且满足转移到下个字符了的话一定先转移;

 1 #include<cstdio>
 2 #include<cstring>
 3 #include<iostream>
 4 #include<algorithm>
 5 #include<cstdlib>
 6 #include<cmath>
 7 using namespace std;
 8 typedef long long LL;
 9 const int N = 100+10;
10 const LL Mod = 100000000;
11 char a[N],b[N];
12 char stand[N];
13 int flag;
14 void init(){
15     int lena = strlen(a), lenb = strlen(b);
16     reverse(a,a+lena);
17     reverse(b,b+lenb);
18     flag = 1;
19     for (int i = 0; i < lena; i++) if (a[i] != '0') flag = 0;
20     if (lena == 1 && a[0] == '0') flag = 1;
21     int mark = 1;
22     if (!flag)
23     for (int i = 0; i < lena; i++) {
24         if (mark) {
25             if (a[i] > '0') {
26                 a[i] = a[i] - 1;
27                 mark = 0;
28                 break;
29             }
30             else {
31                 a[i] = '9';
32                 mark = 1;
33             }
34         }
35     }
36 }
37 int dig[N];
38 int end;
39 int dp[N][10][N][2][2];
40 
41 int check(int pre,int nw,int id) {
42     if (id >= end) return 0;
43     if (pre < nw && stand[id] == '/') return 1;
44     if (pre == nw && stand[id] == '-') return 1;
45     if (pre > nw && stand[id] == '\\') return 1;
46     return 0;
47 
48 }
49 int dfs(int pos,int pre,int loc,int cc,int limit,int fg) {
50     if (pos < 0) return loc == end - 1 && cc >= 2;
51     int t = 0;
52     if (cc >= 2) t = 1;
53     if (!limit && dp[pos][pre][loc][t][fg] != -1)
54         return dp[pos][pre][loc][t][fg];
55     int last = limit ? dig[pos] : 9;
56     int ret = 0;
57     for (int i = 0; i <= last; i++) {
58         if (!fg) {
59             ret = (ret + dfs(pos-1,i,loc,i != 0,limit && (i==last),fg || i)) % Mod;
60             continue;
61         }
62         if (cc >= 2 &&  check(pre,i,loc+1) && stand[loc] == stand[loc+1]){
63             ret = (ret + dfs(pos-1,i,loc+1,2,limit && (i==last),fg || i)) % Mod;
64             continue;
65         }
66 
67         if (cc >= 2 &&  check(pre,i,loc+1) && stand[loc] != stand[loc+1]){
68             ret = (ret + dfs(pos-1,i,loc+1,2,limit && (i==last),fg || i)) % Mod;
69         }
70         if (check(pre,i,loc))
71             ret = (ret + dfs(pos-1,i,loc,cc+1,limit && (i==last),fg || i) ) % Mod;
72     }
73     if (!limit ) {
74         dp[pos][pre][loc][t][fg] = ret;
75     }
76     return ret;
77 }
78 int solve(char *s) {
79     int len = strlen(s);
80     for (int i = 0; i < len; i++) {
81         dig[i] = s[i] - '0';
82     }
83     while (dig[len-1] == 0) len--;
84    // dig[len++] = 0;
85     return dfs(len-1,0,0,0,1,0);
86 }
87 int main(){
88     while (~scanf("%s%s%s",stand,a,b)) {
89         init();
90         memset(dp,-1,sizeof(dp));
91         end = strlen(stand);
92       //  cout<<a<<endl;
93       //  cout<<b<<endl;
94         if (flag) printf("%08d\n",solve(b));
95         else printf("%08d\n",(solve(b) - solve(a) + Mod) % Mod);
96     }
97     return 0;
98 }
View Code

 

cf 55D http://codeforces.com/problemset/problem/55/D

题意:求[L,R]之间能整除自己每一位的数的个数;

分析:1~9的最小公倍数为2520,同时记录下那些数出现过因为0,1不许要1<<8,cf内存大

dfs(pos,sta,mod,limit);

 1 #include<cstdio>
 2 #include<cstring>
 3 #include<iostream>
 4 #include<algorithm>
 5 #include<cmath>
 6 using namespace std;
 7 typedef long long LL;
 8 const int Mod = 2520;
 9 LL dp[20][1<<8][2520];
10 
11 LL a,b;
12 int dig[20];
13 int check(int sta,int mod) {
14     for (int i = 2; i < 10; i++) {
15         if (sta & (1<<(i-2))) {
16             if (mod % i) return 0;
17         }
18     }
19     return 1;
20 }
21 LL dfs(int pos,int sta,int mod,int limit) {
22     if (pos < 0) return check(sta,mod);
23     if (!limit && dp[pos][sta][mod] != -1) return dp[pos][sta][mod];
24     int last = limit ? dig[pos] : 9;
25     LL ret = 0;
26     for (int i = 0; i <= last; i++) {
27         int t = sta;
28         if (i >= 2) t |= 1<<(i-2);
29         ret += dfs(pos-1,t,(mod * 10 + i) % Mod,limit && (i == last));
30     }
31     if (!limit) {
32         dp[pos][sta][mod] = ret;
33     }
34     return ret;
35 }
36 LL solve(LL n) {
37     int len = 0;
38     while (n) {
39         dig[len++] = n % 10;
40         n /= 10;
41     }
42     return dfs(len-1,0,0,1);
43 }
44 int main(){
45     memset(dp,-1,sizeof(dp));
46     int T; scanf("%d",&T);
47     while (T--) {
48         cin>>a>>b;
49         cout<<solve(b) - solve(a-1)<<endl;
50     }
51     return 0;
52 }
View Code

 

 Foj 2042 http://acm.fzu.edu.cn/problem.php?pid=2042

题意:求[a,b]与[c,d]之间 xor值大于e的 sum += i ^ j;

分析:

dfs(pos,limita,limitb,limitc);

如果只是记录dp[pos]的话会TLE,所以记录dp[pos][limita][limitb][limitc];

这样的话每次solve()都要初始化;

递归版本的数位Dp记录真的很灵活,全看题目;

 1 #include<cstdio>
 2 #include<cstring>
 3 #include<algorithm>
 4 #include<iostream>
 5 #define MP make_pair
 6 using namespace std;
 7 const int N = 66;
 8 typedef long long LL;
 9 const LL Mod = 1000000007;
10 typedef pair<LL,LL> pLL;
11 pLL dp[N][2][2][2];
12 LL a,b,c,d,e;
13 int diga[N],digb[N],digc[N];
14 void change(int dig[],LL n,int &len) {
15     len = 0;
16     while (n) {
17         dig[len++] = n % 2;
18         n /= 2;
19     }
20 }
21 pLL dfs(int pos,int limita,int limitb,int limitc) {
22     if (pos < 0) {
23           return !limitc ? MP(1,0) : MP(0,0);
24     }
25     if (dp[pos][limita][limitb][limitc].first != -1 && dp[pos][limita][limitb][limitc].second != -1) return dp[pos][limita][limitb][limitc];
26     int lasta = limita ? diga[pos] : 1;
27     int lastb = limitb ? digb[pos] : 1;
28     int lastc = limitc ? digc[pos] : -1;
29     pLL ret = MP(0,0);
30     for (int i = 0; i <= lasta; i++) {
31         for (int j = 0; j <= lastb; j++) {
32             int t = (i ^ j);
33             if (t >= lastc) {
34                pLL cnt = dfs(pos-1,limita && (i == lasta),limitb && (j == lastb),limitc && (t == lastc));
35                ret.first = (ret.first + cnt.first) % Mod;
36                ret.second = ((ret.second + (1ll<<pos) * (i^j)  % Mod * cnt.first % Mod) % Mod + cnt.second) % Mod;
37             }
38         }
39     }
40     dp[pos][limita][limitb][limitc] = ret;
41 
42     return ret;
43 }
44 
45 LL solve(LL a,LL b,LL c) {
46         for (int i = 0; i < N; i++)
47               for (int j = 0; j < 2; j++)
48                   for (int k = 0; k < 2; k++)
49                   for (int x = 0; x < 2; x++) dp[i][j][k][x] = MP(-1,-1);
50 
51     int lena,lenb,lenc;
52     change(diga,a,lena);
53     change(digb,b,lenb);
54     change(digc,c,lenc);
55     int len = max(lena,max(lenb,lenc));
56     while (lena < len) diga[lena++] = 0;
57     while (lenb < len) digb[lenb++] = 0;
58     while (lenc < len) digc[lenc++] = 0;
59  
60     return (dfs(len-1,1,1,1).second + Mod) % Mod;
61 }
62 int main(){
63     int T,cas = 0; scanf("%d",&T);
64     while (T--) {
65 
66         printf("Case %d: ",++cas);
67         cin>>a>>b>>c>>d>>e;
68 
69         cout<<((solve(b,d,e) - solve(b,c-1,e) + Mod) % Mod  - solve(a-1,d,e) + solve(a-1,c-1,e) + Mod) % Mod<<endl;
70     }
71     return 0;
72 }
View Code

 

 

 

posted @ 2013-11-11 20:19  Rabbit_hair  阅读(5674)  评论(0编辑  收藏  举报