【学习笔记】(27) 整体 DP
1.算法简介
整体 DP 就是用线段树合并维护 DP。
有一些问题,通常见于二维的DP,有一维记录当前x的信息,但是这一维过大无法开下,O(nm) 也无法通过。
但是如果发现,对于 x,在第二维的一些区间内,取值都是相同的,并且这样的区间是有限个,就可以批量处理。
所以我们就可以用线段树来维护 DP。
对于序列的问题,可以直接扫过去,修改某些位置的点,或者线段树合并。
对于树上的问题,线段树合并。
2.例题
Ⅰ. P4577 [FJOI2018] 领导集团问题
设
-
-
在合并完儿子后进行此操作。
对于转移 1,显然直接线段树合并即可,对于转移 2 ,发现其实是区间 +,且
#include<bits/stdc++.h> #define pb push_back using namespace std; const int N = 2e5 + 67; int read(){ int x = 0, f = 1; char ch = getchar(); while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();} while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();} return x * f; } bool _u; int n, cnt; int a[N], b[N], rt[N]; int ls[N << 5], rs[N << 5], lz[N << 5], mn[N << 5]; vector<int> e[N]; void pushup(int x){ mn[x] = min(mn[ls[x]], mn[rs[x]]) + lz[x]; } void modify(int &x, int l, int r, int L, int R){ if(!x) x = ++cnt; if(L <= l && r <= R) return ++lz[x], ++mn[x], void(); int mid = (l + r) >> 1; if(L <= mid) modify(ls[x], l, mid, L, R); if(R > mid) modify(rs[x], mid + 1, r, L, R); pushup(x); } int query1(int x, int l, int r, int p){ if(l == r && !x) return lz[x]; int mid = (l + r) >> 1; if(p <= mid) return lz[x] + query1(ls[x], l, mid, p); else return lz[x] + query1(rs[x], mid + 1, r, p); } int query2(int x, int l, int r, int val){ if(l == r) return l; int mid = (l + r) >> 1; val -= lz[x]; if(mn[ls[x]] <= val) return query2(ls[x], l, mid, val); else return query2(rs[x], mid + 1, r, val); } int merge(int x, int y){ if(!x || !y) return x + y; ls[x] = merge(ls[x], ls[y]), rs[x] = merge(rs[x], rs[y]); lz[x] += lz[y], pushup(x); return x; } void dfs(int x){ for(auto y : e[x]) dfs(y), rt[x] = merge(rt[x], rt[y]); int tmp = query1(rt[x], 1, n, a[x]); modify(rt[x], 1, n, query2(rt[x], 1, n, tmp), a[x]); } bool _v; int main(){ cerr << abs(&_u - &_v) / 1048576.0 << " MB\n"; n = read(); for(int i = 1; i <= n; ++i) b[i] = a[i] = read(); sort(b + 1, b + 1 + n); for(int i = 1; i <= n; ++i) a[i] = lower_bound(b + 1, b + 1 + n, a[i]) - b; for(int i = 2, f; i <= n; ++i) f = read(), e[f].pb(i); dfs(1); printf("%d\n", query1(rt[1], 1, n, 1)); return 0; }
Ⅱ.CF490F Treeland Tour
#include<bits/stdc++.h> #define pb push_back using namespace std; const int N = 6e3 + 67, LIM = 1e6; int read(){ int x = 0, f = 1; char ch = getchar(); while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();} while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();} return x * f; } bool _u; int n, ans, cnt; int a[N], rt[N]; int ls[N << 5], rs[N << 5], pre[N << 5], suf[N << 5]; vector<int> e[N]; void modify(int &x, int l, int r, int p, int v, int *val){ if(!x) x = ++cnt; val[x] = max(val[x], v); if(l == r) return ; int mid = (l + r) >> 1; if(p <= mid) modify(ls[x], l, mid, p, v, val); else modify(rs[x], mid + 1, r, p, v, val); } int merge(int x, int y){ if(!x || !y) return x + y; pre[x] = max(pre[x], pre[y]), suf[x] = max(suf[x], suf[y]); ans = max(ans, max(pre[ls[x]] + suf[rs[y]], suf[rs[x]] + pre[ls[y]])); ls[x] = merge(ls[x], ls[y]), rs[x] = merge(rs[x], rs[y]); return x; } int query(int x, int l, int r, int L, int R, int *val){ if(!x || L > R) return 0; if(L <= l && r <= R) return val[x]; int mid = (l + r) >> 1, ans = 0; if(L <= mid) ans = max(ans, query(ls[x], l, mid, L, R, val)); if(R > mid) ans = max(ans, query(rs[x], mid + 1, r, L, R, val)); return ans; } void dfs(int x, int fa){ int ns = 0, np = 0; for(auto y : e[x]){ if(y == fa) continue; dfs(y, x); int tp = query(rt[y], 1, LIM, 1, a[x] - 1, pre); int ts = query(rt[y], 1, LIM, a[x] + 1, LIM, suf); ans = max(ans, max(np + ts, ns + tp) + 1); ns = max(ns, ts), np = max(np, tp); rt[x] = merge(rt[x], rt[y]); } modify(rt[x], 1, LIM, a[x], np + 1, pre); modify(rt[x], 1, LIM, a[x], ns + 1, suf); } bool _v; int main(){ cerr << abs(&_u - &_v) / 1048576.0 << " MB\n"; n = read(); for(int i = 1; i <= n; ++i) a[i] = read(); for(int i = 1; i < n; ++i){ int u = read(), v = read(); e[u].pb(v), e[v].pb(u); } dfs(1, 0); printf("%d\n", ans); return 0; }
Ⅲ. P6773 [NOI2020] 命运
发现对于 若干点对
所以我们可以设
-
的权值 为 , -
的权值 为 ,
设
那么转移式就可以写成
线段树合并即可。
#include<bits/stdc++.h> #define ll long long using namespace std; const int N = 5e5 + 67, mod = 998244353; int read(){ int x = 0, f = 1; char ch = getchar(); while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();} while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();} return x * f; } bool _u; int n, m, nod; int dep[N], ls[N << 5], rs[N << 5], rt[N]; ll laz[N << 5], sum[N << 5]; void add(ll &x, ll y){x += y; if(x >= mod) x -= mod;} void mul(ll &x, ll y){x = x * y % mod;} struct Edge{ int tot, to[N << 1], nxt[N << 1], hd[N]; void add(int u, int v){to[++tot] = v, nxt[tot] = hd[u], hd[u] = tot;} }e, g; void build(int &u, int l, int r, int p){ u = ++nod; sum[u] = laz[u] = 1; if(l == r) return ; int mid = (l + r) >> 1; if(p <= mid) build(ls[u], l, mid, p); else build(rs[u], mid + 1, r, p); } void pushup(int u){sum[u] = (sum[ls[u]] + sum[rs[u]]) % mod;} void pushdown(int u){ if(laz[u] != 1){ mul(sum[ls[u]], laz[u]), mul(sum[rs[u]], laz[u]); mul(laz[ls[u]], laz[u]), mul(laz[rs[u]], laz[u]); } laz[u] = 1; } ll query(int u, int l, int r, int p){ if(!u || r <= p) return sum[u]; int mid = (l + r) >> 1; ll ans = 0; pushdown(u); if(p > mid) ans = query(rs[u], mid + 1, r, p); return add(ans, query(ls[u], l, mid, p)), ans; } int merge(int x, int y, int l, int r, ll &s1, ll &s2){ //s1 表示 g[v][dep[u]] + g[v][i], s2 表示 g[u][i - 1] if(!x && !y) return 0; if(!y) return add(s2, sum[x]), mul(laz[x], s1), mul(sum[x], s1), x; if(!x) return add(s1, sum[y]), mul(laz[y], s2), mul(sum[y], s2), y; if(l == r){ ll tmp = sum[x]; add(s1, sum[y]), mul(sum[x], s1); add(sum[x], sum[y] * s2 % mod), add(s2, tmp); //注意先后顺序 return x; } pushdown(x), pushdown(y); int mid = (l + r) >> 1; ls[x] = merge(ls[x], ls[y], l, mid, s1, s2); rs[x] = merge(rs[x], rs[y], mid + 1, r, s1, s2); return pushup(x), x; } void dfs(int x, int fa){ int mxd = 0; dep[x] = dep[fa] + 1; for(int i = g.hd[x]; i; i = g.nxt[i]) mxd = max(mxd, dep[g.to[i]]); //找到最深的点 build(rt[x], 0, n, mxd); for(int i = e.hd[x]; i; i = e.nxt[i]){ int y = e.to[i]; if(y == fa) continue; dfs(y, x); ll zx = query(rt[y], 0, n, dep[x]), zxq = 0; rt[x] = merge(rt[x], rt[y], 0, n, zx, zxq); } } bool _v; int main(){ cerr << abs(&_u - &_v) / 1048576.0 << " MB\n"; // freopen("destiny4.in", "r", stdin); n = read(); for(int i = 1; i < n; ++i){ int u = read(), v = read(); e.add(u, v), e.add(v, u); } m = read(); for(int i = 1; i <= m; ++i){ int u = read(), v = read(); g.add(v, u); } dfs(1, 0); printf("%lld\n", query(rt[1], 0, n, 0)); return 0; }
本文作者:南风未起
本文链接:https://www.cnblogs.com/jiangchen4122/p/17713280.html
版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步