P5024 保卫王国
题目描述
Z 国有 n 座城市,(n−1) 条双向道路,每条双向道路连接两座城市,且任意两座城市都能通过若干条道路相互到达。
Z 国的国防部长小 Z 要在城市中驻扎军队。驻扎军队需要满足如下几个条件:
- 一座城市可以驻扎一支军队,也可以不驻扎军队。
- 由道路直接连接的两座城市中至少要有一座城市驻扎军队。
- 在城市里驻扎军队会产生花费,在编号为 iii 的城市中驻扎军队的花费是 pi。
小 Z 很快就规划出了一种驻扎军队的方案,使总花费最小。但是国王又给小 Z 提出了 m 个要求,每个要求规定了其中两座城市是否驻扎军队。小 Z 需要针对每个要求逐一给出回答。具体而言,如果国王提出的第 j 个要求能够满足上述驻扎条件(不需要考虑第 j 个要求之外的其它要求),则需要给出在此要求前提下驻扎军队的最小开销。如果国王提出的第 j 个要求无法满足,则需要输出 −1。现在请你来帮助小 Z。
输入格式
第一行有两个整数和一个字符串,依次表示城市数 n,要求数 m 和数据类型 type。
第二行有 n 个整数,第 i 个整数表示编号 i 的城市中驻扎军队的花费 pi。
接下来 (n−1) 行,每行两个整数 u,v,表示有一条 u 到 v 的双向道路。
接下来 m 行,每行四个整数 a,x,b,y,表示一个要求是在城市 a 驻扎 x 支军队,在城市 b 驻扎 y 支军队。其中,x,y 的取值只有 0 或 1:
- 若 x 为 0,表示城市 a 不得驻扎军队。
- 若 x 为 1,表示城市 a 必须驻扎军队。
- 若 y 为 0,表示城市 b 不得驻扎军队。
- 若 y 为 1,表示城市 b 必须驻扎军队。
输入文件中每一行相邻的两个数据之间均用一个空格分隔。
输出格式
输出共 m 行,每行包含一个个整数,第 j 行表示在满足国王第 j 个要求时的最小开销, 如果无法满足国王的第 j 个要求,则该行输出 −1。
输入输出样例
输入 #1
5 3 C3
2 4 1 3 9
1 5
5 2
5 3
3 4
1 0 3 0
2 1 3 1
1 0 5 0
输出 #1
12
7
-1
动态DP题。
我们首先可以推出DP方程:\(dpi,1=∑min(dpj,0,dpj,1)\)
\(dpi, 0 = \sum_{j} dpj, 1 + a_i\)
对于每一次修改,我们要\(O(n)\)的改一条链,很不妙,所以用树链剖分,\(O(logn)\)的修改一条链。
我们用\(f_{i, 0/1}\)表示以\(i\)为根的这颗树内,轻儿子的ans。
所以DP方程还可以写成:\(dp_{i, 0} = f_{j, 0} + dp_{son, 1}\)
\(dp_{i, 1} = f_{i, 1} + min(dp_{son, 0}, dp_{son, 1})\)
这里\(son\)是轻儿子。
转移矩阵:
这里\(son\)是重儿子。
修改一下矩阵乘法规则,把加法改为取\(min\)
#include <bits/stdc++.h>
#pragma GCC optimize(2)
#define ls(o) (o << 1)
#define rs(o) (o << 1 | 1)
#define mid ((l + r) >> 1)
#define int long long
using namespace std;
inline long long read() {
long long s = 0, f = 1; char ch;
while(!isdigit(ch = getchar())) (ch == '-') && (f = -f);
for(s = ch ^ 48;isdigit(ch = getchar()); s = (s << 1) + (s << 3) + (ch ^ 48));
return s * f;
}
const int N = 1e5 + 5, inf = 1e11;
int n, m, cnt, tot;
string C;
int p[N], fa[N], id[N], siz[N], hav[N], end[N], top[N], dfn[N], head[N];
int f[N][2], dp[N][2];
struct edge { int to, nxt; } e[N << 1];
struct mat {
int v[2][2];
mat() {
for(int i = 0;i <= 1; i++)
for(int j = 0; j <= 1; j++)
v[i][j] = inf;
}
} t[N << 2], nod[N];
mat operator * (const mat &a, const mat &b) {
mat c;
for(int i = 0;i <= 1; i++)
for(int j = 0;j <= 1; j++)
for(int k = 0;k <= 1; k++)
c.v[i][j] = min(c.v[i][j], a.v[i][k] + b.v[k][j]);
return c;
}
void add(int x, int y) {
e[++cnt].nxt = head[x]; head[x] = cnt; e[cnt].to = y;
}
void get_tree(int x, int Fa) {
siz[x] = 1; fa[x] = Fa;
for(int i = head[x]; i; i = e[i].nxt) {
int y = e[i].to; if(y == Fa) continue;
get_tree(y, x);
siz[x] += siz[y];
if(siz[y] > siz[hav[x]]) hav[x] = y;
}
}
void get_top(int x, int topic) {
dfn[x] = ++tot; id[tot] = x; top[x] = topic;
if(hav[x] == 0) {
end[topic] = tot; return ;
}
get_top(hav[x], topic);
for(int i = head[x]; i; i = e[i].nxt) {
int y = e[i].to; if(y == fa[x] || y == hav[x]) continue;
get_top(y, y);
}
}
void dfs(int x) {
dp[x][1] = f[x][1] = p[x];
for(int i = head[x]; i; i = e[i].nxt) {
int y = e[i].to; if(y == fa[x]) continue;
dfs(y);
dp[x][0] += dp[y][1];
dp[x][1] += min(dp[y][0], dp[y][1]);
if(y != hav[x]) {
f[x][0] += dp[y][1];
f[x][1] += min(dp[y][0], dp[y][1]);
}
}
}
void up(int o) {
t[o] = t[ls(o)] * t[rs(o)];
}
void build(int o, int l, int r) {
if(l == r) {
int x = id[l];
t[o].v[0][0] = t[o].v[0][1] = f[x][1];
t[o].v[1][0] = f[x][0]; nod[l] = t[o];
return ;
}
build(ls(o), l, mid); build(rs(o), mid + 1, r);
up(o);
}
void change(int o, int l, int r, int x) {
if(l == r) { t[o] = nod[l]; return ; }
if(x <= mid) change(ls(o), l, mid, x);
if(x > mid) change(rs(o), mid + 1, r, x);
up(o);
}
mat get_ans(int o, int l, int r, int x, int y) {
if(x <= l && y >= r) return t[o];
if(y <= mid) return get_ans(ls(o), l, mid, x, y);
if(x > mid) return get_ans(rs(o), mid + 1, r, x, y);
return get_ans(ls(o), l, mid, x, y) * get_ans(rs(o), mid + 1, r, x, y);
}
void up_tree(int x, int y) {
nod[dfn[x]].v[0][0] = nod[dfn[x]].v[0][1] += y - p[x];
p[x] = y;
while(x) {
mat Old = get_ans(1, 1, n, dfn[top[x]], end[top[x]]);
change(1, 1, n, dfn[x]);
mat New = get_ans(1, 1, n, dfn[top[x]], end[top[x]]);
int tmp = dfn[fa[top[x]]];
nod[tmp].v[0][0] = nod[tmp].v[0][1] += min(New.v[0][0], New.v[1][0]) - min(Old.v[0][0], Old.v[1][0]);
nod[tmp].v[1][0] += New.v[0][0] - Old.v[0][0];
x = fa[top[x]];
}
}
signed main() {
n = read(); m = read(); cin >> C;
for(int i = 1;i <= n; i++) p[i] = read();
for(int i = 1, x, y;i <= n - 1; i++) {
x = read(); y = read();
add(x, y); add(y, x);
}
get_tree(1, 0); get_top(1, 1); dfs(1); build(1, 1, n);
for(int i = 1, a, x, b, y;i <= m; i++) {
a = read(); x = read(); b = read(); y = read();
int tmp1 = p[a], tmp2 = p[b];
if(x == 0 && y == 0 && (fa[a] == b || fa[b] == a)) { printf("-1\n"); continue; }
up_tree(a, x ? tmp1 - inf : tmp1 + inf);
up_tree(b, y ? tmp2 - inf : tmp1 + inf);
mat ans = get_ans(1, 1, n, dfn[1], end[1]);
int res = min(ans.v[0][0], ans.v[1][0]);
if(x == 1) res += inf;
if(y == 1) res += inf;
up_tree(a, tmp1); up_tree(b, tmp2);
printf("%lld\n", res);
}
return 0;
}
(不开\(O_2\)过不去qwq)