HDU4352 XHXJ's LIS 题解 数位DP
题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4352
题目大意:
求区间 \([L,R]\) 范围内最长上升子序列(Longest increasing subsequence,简称LIS)长度为 \(k\) 的数的数量。
举个例子:
\(123\) 的LIS只有一个\(123\),所以它的LIS的长度是 \(3\);
\(101\) 的LIS只有一个\(01\),所以它的LIS的长度是 \(2\);
\(132\) 的LIS有\(13\)和\(12\),所以它的LIS的长度是 \(2\)。
现在每次给你三个数 \(L,R,k\) ,你要求区间 \([L,R]\) 范围内LIS长度为 \(k\) 的数有多少个。
解题思路:
本题使用 数位DP 进行求解。
但是我觉得比较必要的先决条件是:你要对如何 使用二分的方法求解LIS 有一个比较深刻的理解!
虽然这并不是必须的,但是这能够帮助你理解状态转移的过程。
设状态 \(dp[pos][sta][k]\) 表示对于当前的这个 \(k\):
- 当前所处的数位为 \(pos\),
- 当前LIS的状态为 \(sta\)
时的数量。
\(sta\) 涉及状态压缩的思想,他表示当前LIS中的元素由哪些组成。
一开始初始时候的 \(sta\) 为 \(0000000000\)(10个\(0\))。
在某一阶段,
如果当前已经选择了 \(a[0]\), \(a[1]\) 和 \(a[3]\) ,那么当前的状态就是 \(0000001011\);
如果当前已经选择了 \(a[2]\), \(a[4]\) 和 \(a[7]\) ,那么当前的状态就是 \(0010010100\)。
接下来我们来举例一个数 \(15234\) 来演示我们数位DP的过程:
初始时 \(sta\) 为 \(0000000000\);
加入 \(1\) ,此时状态变成 \(0000000010\);
加入 \(5\) ,此时状态变成 \(0000100010\);
加入 \(2\) ,此时状态变成 \(0000000110\),
注意:这里是最重要的地方!!
为什么加入 \(2\) 之后 \(5\) 对应的位置会变成 \(0\) 呢?
因为我们这里记录的状态就是我们二分LIS对应的状态,
刚加入 \(5\) 的时候,状态是 \(0000100010\) ,它表示新加入的元素要构成一个长度为 \(2\) 的LIS,必须比 \(1\) 大,
新加入的元素要构成一个长度为 \(3\) 的LIS,必须比 \(5\) 大。
而加入 \(2\) 之后,情况就大大改观了,因为此时要构成一个长度为 \(2\) 的LIS,只需要比 \(2\) 大就可以了。
所以对于当前状态 \(sta\) 和当前数位要放的数字 \(i\) ,
如果 \(sta\) 的第 \(i\) 位为 \(1\) ,那么新的状态仍旧是 \(sta\)(因为LIS中存在 \(i\));
如果 \(sta\) 的第 \(i\) 为为 \(0\) ,那么:
- 如果 \(sta\) (没有特别强调都是指二进制)中没有任何比 \(i\) 大的位上为 \(1\) ,则新状态就是
sta | (1<<i)
; - 否则,将比 \(i\) 大的最小的那位置为 \(0\),再将第 \(i\) 位置为 \(1\),就是新的状态。
实现代码如下:
#include <bits/stdc++.h>
using namespace std;
long long f[22][1030][10];
int n, k, a[22];
void init() {
memset(f, -1, sizeof(f));
}
int new_sta(int pos, int sta, int i) {
if (!sta && i==0 && pos>0) return 0;
if (!(sta>>(i+1)) || (sta&(1<<i))) return sta | (1<<i);
for (int k = k = i+1; k < 10; k ++) if (sta & (1<<k)) return (sta ^ (1<<k)) | (1<<i);
}
long long dfs(int pos, int sta, bool limit) {
if (pos < 0) return __builtin_popcount(sta) == k ? 1 : 0;
if (!limit && f[pos][sta][k] != -1) return f[pos][sta][k];
int up = limit ? a[pos] : 9;
long long tmp = 0;
for (int i = 0; i <= up; i ++) {
tmp += dfs(pos-1, new_sta(pos, sta, i), limit && i==up);
}
if (!limit) f[pos][sta][k] = tmp;
return tmp;
}
long long get_num(long long x) {
int pos = 0;
while (x) {
a[pos++] = x % 10;
x /= 10;
}
return dfs(pos-1, 0, true);
}
int T;
long long L, R;
int main() {
init();
scanf("%d", &T);
for (int cas = 1; cas <= T; cas ++) {
scanf("%lld%lld%d", &L, &R, &k);
printf("Case #%d: %lld\n", cas, get_num(R) - get_num(L-1));
}
return 0;
}