HDU2227Find the nondecreasing subsequences(树状数组+DP)
题目大意就是说帮你给出一个序列a,让你求出它的非递减序列有多少个。
设dp[i]表示以a[i]结尾的非递减子序列的个数,由题意我们可以写出状态转移方程:
dp[i] = sum{dp[j] | 1<=j<i && a[j] <= a[i]} + 1.
这样一来这里面所有的dp[]值的和就是最后的结果。
但是这个状态转移方程很明显复杂度是O(n^2),但是n可以达到100000,很明显会超时。既然是求前导和,很明显我们就应该可以想到用树状数组(虽然我怎么也不可能想到==!),这样一来那么复杂度就可以降到O(nlogn)。
那么怎么求前导和呢??也并不是所有的dp[j](1<=j<i)都要被加进去啊,只有满足a[j]<=a[i]时dp值才可以被计算在内。。。
解决办法就是先将原数组复制一份,然后排序,然后再按照原顺序找出每一个数的排序后的所在位置,然后计算这个位置的dp[]值,可以通过一例看出他的正确性:
原数组: 8 5 3 4 1
排序后: 1 3 4 5 8
可以看出原数组的每一个数对应到排序后的下标就是:5 4 2 3 1
没有计算前,树状数组里的值全为0,然后
1、找到8的位置,并计算以‘8’结尾的dp[]的值,也就是计算‘8’在排序后所在位置5的值, 计算dp[5] = 0 + 1 = 1
2、然后找到‘5’在排序后的位置4,由于‘8’>‘5’,所以以‘5’结尾的dp值应该也是1,正好排序后‘5’在第4个,在‘8’前面,自然dp[4]计算出来还是1
3、同理,‘3’出现在第二个,dp[2] = 1
4、然后到‘4’,他在排序后出现在第3 个,而原数组中‘4’之前有一个数‘3’,所以计算出来以‘4’结尾的dp[]值应该就是以‘3’结尾的dp值+1等于2,而我们看排序后4出现在第3个,而第3个之前又正好有一个dp[2]=1已经被计算出来了,这样dp值的前导和就是1,从而dp[3] = dp[2] + 1 = 2.
5、最后dp[1] = 1.所以最后结果就是1+1+1+2+1 = 6
其实上面的排序找到下标就是为了保证每个数计算出来的值都是满足a[j] < a[i]时,所计算出来的dp值。这也就是原题的解。
而要实现找到原数组在排序后的位置,我们只需要二分查找就可以了,又因为原数组可能会有相同的数,为了找到的是同一个标号,所以需要二分查找下限(或者上限)。
1 #include <cstdio> 2 #include <cstring> 3 #include <algorithm> 4 using namespace std; 5 6 #define mem(a) memset(a,0,sizeof(a)) 7 #define mod 1000000007 8 #define MAXN 100010 9 10 int num[MAXN], d[MAXN], N, DP[MAXN]; 11 12 int lowbit(int x) 13 { 14 return x & -x; 15 } 16 17 int getSum(int k) 18 { 19 int ans = 0; 20 while(k>=1) 21 { 22 ans = (ans + DP[k]) % mod; 23 k -= lowbit(k); 24 } 25 return ans; 26 } 27 28 void edit(int k,int val) 29 { 30 while(k<=N) 31 { 32 DP[k] = (DP[k] + val) % mod; 33 k += lowbit(k); 34 } 35 } 36 37 38 int bsearch(int num) 39 { 40 int x = 1, y = N+1, mid; 41 while(y > x ) 42 { 43 mid = (x+y)/2; 44 if(d[mid] == num && d[mid-1]<num) return mid; 45 if(d[mid] >= num) y = mid; 46 else x = mid+1; 47 } 48 return mid; 49 } 50 51 int main() 52 { 53 while(~scanf("%d", &N)) 54 { 55 mem(DP); 56 mem(d); mem(num); 57 for(int i = 1; i <= N; i ++) 58 { 59 scanf("%d", &num[i]); 60 d[i] = num[i]; 61 } 62 sort(d+1, d+N+1); 63 for(int i=1;i<=N;i++) 64 { 65 int id = bsearch(num[i]); 66 edit(id, getSum(id)+1); 67 } 68 printf("%d\n", getSum(N)); 69 } 70 return 0; 71 }