2017北理校赛H题 青蛙过河(线段树, dp, 离散化)
题面:
这题的O(n^2) dp很容易想出来,由于我没有数据,也没法提交测试。所以就拿这个dp来简单对拍了。
1 #include <bits/stdc++.h> 2 using namespace std; 3 4 const int maxn = 1010; 5 int dp[maxn]; 6 int n, c; 7 int a[maxn]; 8 9 int main() { 10 // freopen("in", "r", stdin); 11 while(~scanf("%d%d",&n,&c)) { 12 memset(dp, 0, sizeof(dp)); 13 for(int i = 1; i <= n; i++) scanf("%d", &a[i]); 14 dp[1] = 1; 15 for(int i = 2; i <= n; i++) { 16 dp[i] = 1; 17 for(int j = 1; j < i; j++) { 18 if((int)abs(a[i]-a[j]) >= c) dp[i] = max(dp[i], dp[j]+1); 19 } 20 } 21 printf("%d\n", dp[n]); 22 } 23 return 0; 24 }
这道题的n非常大,显然这个O(n^2)的dp是非常不给力的,考虑转移状态:dp(i)一定由之前的dp(j)转移过来,并且这个dp(j)一定是a(j)符合条件并且dp(j)是最大的。可以用线段树维护某个数字x的[-inf,x-c]与[x+c,inf]的dp值,这样查询的时候可以直接查询这两个区间,更新的时候只需要更新x点的dp值就行了。
数据范围很大,有1e9,因此要离散化。离散化也很好操作,考虑查询,每一个点有2个查询的点:x-c和x+c,加上-inf和inf再加上x,把这几个部分离散化就行了。
1 #include <bits/stdc++.h> 2 using namespace std; 3 4 #define lrt rt << 1 5 #define rrt rt << 1 | 1 6 typedef long long LL; 7 typedef struct Seg { 8 LL sign, val; 9 }Seg; 10 const LL maxn = 300300; 11 const LL inf = 2000000000LL; 12 Seg seg[maxn<<3]; 13 LL n, c; 14 LL a[maxn], h[maxn*4], hcnt; 15 LL id(LL x) { return lower_bound(h, h+hcnt, x) - h + 1; } 16 void pushup(LL rt) { seg[rt].val = max(seg[lrt].val, seg[rrt].val); } 17 18 void build(LL l, LL r, LL rt) { 19 seg[rt].sign = -1; seg[rt].val = 0; 20 if(l == r) return; 21 LL mid = (l + r) >> 1; 22 build(l, mid, lrt); 23 build(mid+1, r, rrt); 24 } 25 26 void pushdown(LL rt) { 27 if(~seg[rt].sign) { 28 seg[lrt].sign = seg[rrt].sign = seg[rt].sign; 29 seg[lrt].val = seg[rrt].val = seg[rt].sign; 30 seg[rt].sign = -1; 31 } 32 } 33 34 void update(LL pos, LL val, LL l, LL r, LL rt) { 35 if(l == r) { 36 seg[rt].val = seg[rt].sign = val; 37 return; 38 } 39 pushdown(rt); 40 LL mid = (l + r) >> 1; 41 if(pos <= mid) update(pos, val, l, mid, lrt); 42 else update(pos, val, mid+1, r, rrt); 43 pushup(rt); 44 } 45 46 LL query(LL L, LL R, LL l, LL r, LL rt) { 47 if(L <= l && r <= R) return seg[rt].val; 48 pushdown(rt); 49 LL mid = (l + r) >> 1; 50 LL ret = -1; 51 if(L <= mid) ret = max(ret, query(L, R, l, mid, lrt)); 52 if(mid < R) ret = max(ret, query(L, R, mid+1, r, rrt)); 53 return ret; 54 } 55 56 int main() { 57 // freopen("in", "r", stdin); 58 while(~scanf("%lld%lld",&n,&c)) { 59 hcnt = 0; h[hcnt++] = -inf; h[hcnt++] = inf; 60 for(LL i = 1; i <= n; i++) { 61 scanf("%lld", &a[i]); 62 h[hcnt++] = a[i] - c; 63 h[hcnt++] = a[i]; 64 h[hcnt++] = a[i] + c; 65 } 66 sort(h, h+hcnt); hcnt = unique(h, h+hcnt) - h; 67 build(1, hcnt, 1); 68 LL L = id(-inf), R = id(inf); 69 LL dp = -1; 70 for(LL i = 1; i <= n; i++) { 71 LL l = id(a[i]-c), r = id(a[i]+c); 72 dp = max(query(L, l, 1, hcnt, 1), query(r, R, 1, hcnt, 1)) + 1; 73 update(id(a[i]), dp, 1, hcnt, 1); 74 } 75 printf("%lld\n", dp); 76 } 77 return 0; 78 }