动态 dp
首先看到以下这个问题:
- 有 \(N\) 个饼干,第 \(i\) 个饼干的美味值为 \(a_i\),你可以吃任意不相邻的一些饼干。还有 \(Q\) 次询问,每次询问将修改一个饼干的美味值,问美味值之和最大能是多少。
我们尝试通过线段树的思想来解决,但这时又会有一个新的问题:合并左右儿子的信息。
这里我们介绍一种使用矩阵乘法的做法。我们让线段树维护矩阵的 dp 数组,这里的矩阵表示的是初始状态与当前状态的关系。
我们使用 Max-plus 矩阵,合并时就只许把两边合并起来即可。
代码
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
const int MAXN = 200001;
const ll INF = (ll)(1e18) + 1;
struct Max_plus_Matrix {
int n, m;
ll mat[3][3];
void Clear(int a, int b) {
n = a, m = b;
for(int i = 1; i <= n; ++i) {
for(int j = 1; j <= m; ++j) {
mat[i][j] = -INF;
}
}
}
Max_plus_Matrix operator*(const Max_plus_Matrix &b) {
Max_plus_Matrix ret;
ret.Clear(n, b.m);
for(int i = 1; i <= n; ++i) {
for(int j = 1; j <= b.m; ++j) {
for(int k = 1; k <= b.n; ++k) {
ret.mat[i][j] = max(ret.mat[i][j], mat[i][k] + b.mat[k][j]);
}
}
}
return ret;
}
Max_plus_Matrix operator*=(const Max_plus_Matrix &b) {
return *this = *this * b;
}
};
struct Segment_Tree {
int l[MAXN << 2], r[MAXN << 2], a[MAXN];
Max_plus_Matrix mat[MAXN << 2];
void build(int u, int s, int t) {
l[u] = s, r[u] = t;
if(s == t) {
mat[u].n = mat[u].m = 2;
mat[u].mat[1][1] = 0, mat[u].mat[1][2] = 0, mat[u].mat[2][1] = a[s], mat[u].mat[2][2] = -INF;
return;
}
int mid = (s + t) >> 1;
build(u << 1, s, mid), build((u << 1) | 1, mid + 1, t);
mat[u] = mat[u << 1] * mat[(u << 1) | 1];
}
void update(int u, int p, int x) {
if(l[u] == r[u]) {
mat[u].mat[2][1] = x;
return;
}
if(p <= r[u << 1]) {
update(u << 1, p, x);
}else {
update((u << 1) | 1, p, x);
}
mat[u] = mat[u << 1] * mat[(u << 1) | 1];
}
ll Getans() {
return mat[1].mat[2][1];
}
}tr;
int n, q;
int main() {
ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
cin >> n;
for(int i = 1; i <= n; ++i) {
cin >> tr.a[i];
}
tr.build(1, 1, n);
cin >> q;
for(int i = 1, p, x; i <= q; ++i) {
cin >> p >> x;
tr.update(1, p, x);
cout << tr.Getans() << "\n";
}
return 0;
}
树上动态 dp
给定一颗树,每个点都有一个点权,你可以选择任意一些不相邻的点。还有 \(Q\) 次修改,每次将修改一个点的点权,对于每次修改求选择的点权之和的最大值。
由于树上问题很难处理,所以考虑树链剖分。
对于每一条链,可以得到其单独的矩阵。现在我们考虑把其他分支的结果累加进来。比如以下这个树:
这里一个蓝色的圈表示一条链。这里根节点需要加上分支的贡献。
每次修改一个值,都需要把它到根节点中所有链头的父亲进行更新,使用线段树维护。
代码
#include<bits/stdc++.h>
using namespace std;
const int MAXN = 100001, INF = int(1e7) + 1;
struct Max_plus_Matrix {
int n, m, mat[3][3];
void Clear(int a, int b) {
n = a, m = b;
for(int i = 1; i <= n; ++i) {
for(int j = 1; j <= m; ++j) {
mat[i][j] = -INF;
}
}
}
void Set(int a, int b) {
n = a, m = b;
for(int i = 1; i <= n; ++i) {
for(int j = 1; j <= m; ++j) {
mat[i][j] = (i == j ? 0 : -INF);
}
}
}
Max_plus_Matrix operator*(const Max_plus_Matrix &b) {
Max_plus_Matrix ret;
ret.Clear(b.n, m);
for(int i = 1; i <= b.n; ++i) {
for(int j = 1; j <= m; ++j) {
for(int k = 1; k <= n; ++k) {
ret.mat[i][j] = max(ret.mat[i][j], mat[k][j] + b.mat[i][k]);
}
}
}
return ret;
}
Max_plus_Matrix operator*=(const Max_plus_Matrix &b) {
return *this = *this * b;
}
};
struct Segment_Tree {
int l[MAXN << 2], r[MAXN << 2], dfn[MAXN], a[MAXN];
Max_plus_Matrix mat[MAXN << 2];
void build(int u, int s, int t) {
l[u] = s, r[u] = t;
if(s == t) {
return;
}
int mid = (s + t) >> 1;
build(u << 1, s, mid), build((u << 1) | 1, mid + 1, t);
}
void update(int u, int p, Max_plus_Matrix x) {
if(l[u] == r[u]) {
mat[u] = x;
return;
}
if(p <= r[u << 1]) {
update(u << 1, p, x);
}else {
update((u << 1) | 1, p, x);
}
mat[u] = mat[(u << 1) | 1] * mat[u << 1];
}
Max_plus_Matrix Getmat(int u, int s, int t) {
if(s > t) {
Max_plus_Matrix x;
x.Set(2, 2);
return x;
}
if(l[u] >= s && r[u] <= t) {
return mat[u];
}
Max_plus_Matrix x;
x.Set(2, 2);
if(t >= l[(u << 1) | 1]) {
x *= Getmat((u << 1) | 1, s, t);
}
if(s <= r[u << 1]) {
x *= Getmat(u << 1, s, t);
}
return x;
}
}tr;
int n, m, sz[MAXN], dfn[MAXN], tot, top[MAXN], f[MAXN], son[MAXN], tail[MAXN], dp[MAXN][2];
vector<int> e[MAXN];
void dfs(int u, int fa) {
sz[u] = 1, f[u] = fa;
for(int v : e[u]) {
if(v != fa) {
dfs(v, u);
sz[u] += sz[v];
if(sz[v] > sz[son[u]]) {
son[u] = v;
}
}
}
}
void DFS(int u, int fa) {
dfn[u] = ++tot, tr.dfn[tot] = u;
if(son[u]) {
top[son[u]] = top[u], DFS(son[u], u);
}else {
tail[top[u]] = u;
}
for(int v : e[u]) {
if(v != fa && v != son[u]) {
top[v] = v, DFS(v, u);
Max_plus_Matrix x = tr.Getmat(1, dfn[v], dfn[tail[v]]);
dp[u][0] += max(x.mat[1][1], x.mat[2][1]);
dp[u][1] += x.mat[1][1];
}
}
Max_plus_Matrix x;
x.n = x.m = 2;
x.mat[1][1] = dp[u][0], x.mat[1][2] = dp[u][0], x.mat[2][1] = dp[u][1] + tr.a[u], x.mat[2][2] = -INF;
tr.update(1, dfn[u], x);
}
void update(int u) {
for(; u; ) {
Max_plus_Matrix x = tr.Getmat(1, dfn[top[u]], dfn[tail[top[u]]]);
dp[f[top[u]]][0] -= max(x.mat[1][1], x.mat[2][1]);
dp[f[top[u]]][1] -= x.mat[1][1];
x.n = x.m = 2, x.mat[1][1] = x.mat[1][2] = dp[u][0];
x.mat[2][1] = dp[u][1] + tr.a[u], x.mat[2][2] = -INF;
tr.update(1, dfn[u], x);
x = tr.Getmat(1, dfn[top[u]], dfn[tail[top[u]]]);
dp[f[top[u]]][0] += max(x.mat[1][1], x.mat[2][1]);
dp[f[top[u]]][1] += x.mat[1][1];
u = f[top[u]];
}
}
int Getans() {
Max_plus_Matrix x = tr.Getmat(1, 1, dfn[tail[1]]);
return max(x.mat[1][1], x.mat[2][1]);
}
int main() {
ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
cin >> n >> m;
for(int i = 1; i <= n; ++i) {
cin >> tr.a[i];
tr.a[i] = max(0, tr.a[i]);
}
for(int i = 1, u, v; i < n; ++i) {
cin >> u >> v;
e[u].push_back(v);
e[v].push_back(u);
}
tr.build(1, 1, n);
dfs(1, 0);
top[1] = 1;
DFS(1, 0);
for(int i = 1, u, v; i <= m; ++i) {
cin >> u >> v;
v = max(v, 0);
tr.a[u] = v;
update(u);
cout << Getans() << "\n";
}
return 0;
}