Codeforces 750E 线段树DP
题意:给你一个字符串,有两种操作:1:把某个位置的字符改变。2:询问l到r的子串最少需要删除多少个字符,使得这个子串含有2017子序列,并且没有2016子序列?
思路:线段树上DP,我们设状态0, 1, 2, 3, 4分别为: null, 2, 20, 201, 2017的最小花费,我们用线段树来维互状态转移的花费矩阵,合并相邻的两个子串的时候直接转移即可。
代码:
#include <bits/stdc++.h> #define INF 0x3f3f3f3f #define ls (o << 1) #define rs (o << 1 | 1) using namespace std; const int maxn = 200010; int a[maxn]; char s[maxn]; struct node { int f[5][5]; void init(int x) { for (int i = 0; i < 5; i++) { for (int j = 0; j < 5; j++) { if(i == j) continue; f[i][j] = INF; } } if(x == 2) { f[0][0] = 1, f[0][1] = 0; } else if (x == 0) { f[1][1] = 1, f[1][2] = 0; } else if (x == 1) { f[2][2] = 1, f[2][3] = 0; } else if (x == 7) { f[3][3] = 1, f[3][4] = 0; } else if (x == 6) { f[3][3] = 1; f[4][4] = 1; } else if (x == -1){ for (int i = 0; i < 5; i++) f[i][i] = INF; } else { for (int i = 0; i < 5; i++) f[i][i] = 0; } } void print() { for (int i = 0; i < 5; i++) { for (int j = 0; j < 5; j++) { if(f[i][j] == INF) printf("inf "); else printf("%d ", f[i][j]); } printf("\n"); } } }; node tr[maxn * 4]; node merge(node t1, node t2) { node ans; ans.init(-1); // ans.init(-1); // printf("ans\n"); // ans.print(); // printf("t1\n"); // t1.print(); // printf("t2\n"); // t2.print(); for (int i = 0; i < 5; i++) { for (int j = i; j < 5; j++) { for (int k = i; k <= j; k++) { ans.f[i][j] = min(ans.f[i][j], t1.f[i][k] + t2.f[k][j]); } } } // printf("ans\n"); // ans.print(); return ans; } void build(int o, int l, int r) { if(l == r) { tr[o].init(a[l]); return; } int mid = (l + r) >> 1; build(ls, l, mid); build(rs, mid + 1, r); tr[o] = merge(tr[ls], tr[rs]); } void update(int o, int l, int r, int ql, int qr, int val) { if(l == r) { tr[o].init(val); return; } int mid = (l + r) >> 1; if(ql <= mid) update(ls, l, mid, ql, qr, val); if(qr > mid) update(rs, mid + 1, r, ql, qr, val); tr[o] = merge(tr[ls], tr[rs]); } node query(int o, int l, int r, int ql, int qr) { if(l >= ql && r <= qr) { return tr[o]; } int mid = (l + r) >> 1; node ans; ans.init(-1); if(ql <= mid && qr > mid) ans = merge(query(ls, l, mid, ql, qr), query(rs, mid + 1, r, ql, qr)); else if(ql <= mid) ans = query(ls, l, mid, ql, qr); else if(qr > mid) ans = query(rs, mid + 1, r, ql, qr); return ans; } int main() { int n, m, l, r; scanf("%d%d", &n, &m); scanf("%s", s + 1); for (int i = 1; i <= n; i++) { a[i] = s[i] - '0'; } build(1, 1, n); for (int i = 1; i <= m; i++) { scanf("%d%d", &l, &r); node ans = query(1, 1, n, l, r); if(ans.f[0][4] == INF) printf("-1\n"); else printf("%d\n", ans.f[0][4]); } }