P6021 洪水
upd on 2021.5.6:更新了一些描述 和 公式。
半年之后被迫营业讲 ddp 的时候才发现原来的题解出了亿点bug。
题意:给一棵树,每个点有点权 \(a_i\),每次询问:在某个子树内以点权为代价删除(堵上)一些点使得根与子树内所有叶子不连通的最小代价;带单点修改。
不带修情况:
设 \(f_i\) 表示把以 \(i\) 为根的子树完全堵上的答案。
转移简洁并且单点修改使我们想到使用 ddp 来维护。
分离轻重儿子(这里的 \(son(u)\) 表示 \(u\) 的重儿子):
其中 \(g_u\) 是轻儿子的贡献
把转移方程写成矩阵
重定义广义矩阵乘法:
构造矩阵:
如果矩阵是一维的
那么 \(x=g_u,y=a_u-f_{son(u)}\)
发现左矩阵做了一个重儿子的东西,不能做ddp。
考虑再加一维:
因为 \(a_u=a_u+0\) ,考虑直接把 \(p,q\) 设成 \(0\) , 那么根据转移方程,\(x=g_u,y=a_u\) ,这样 \(f_u\) 已经被正确表示了。
但是底下的 \(q\) 在矩乘之后不一定是 \(0\) ,考虑通过 \(z,w\) 来维护 \(q\) 。
直接展开 \(q\) ,\(q=\min(z+f_{son(u)},w)\) (\(p=0\) 就不写了)
\(f_{son(u)}\) 是非负的,所以 \(w=0,z\ge 0\) 即可
矩阵构造完毕!
封死一颗子树的代价就是 \(f_u\)。
这里要提一个小细节,就是叶子节点没有轻儿子的时候 \(g\) 怎么办。
我想到了两种处理方法:
一种处理方法是把 \(g_u\) 设为 \(a_u\),因为叶子节点在ddp的时候要满足 \(f_u=g_u\)。
还有一种方法就是把 \(g_u\) 设成 \(+\infty\),直接禁止从“封死所有子树”这种方法的转移。并且矩阵左上角不和自己的 \(a\) 取 \(\min\),只维护轻子树的 dp值 和,调用 dp 值的时候再和自己的 \(a\) 取 \(\min\)。
第一种在树剖的时候比较好搞;如果用 LCT 维护我暂时没想到什么好的维护方法所以用了第二种。
修改
把 \(x\) 加 \(v\)
直接影响的就是这个节点的 \(a_u\)。
但是叶子节点还得同时更改 \(g_u\),千万别忘。
至于轻边父亲的修改,减去原贡献加上新贡献就好了。
LCT 同理,不过在修改一个节点矩阵的时候要先 access
再 splay
,这时候修改它对于任何节点都是没有影响的。
查询
树剖直接用线段树把这个点到重链底端的矩阵全部乘起来就好了。
LCT 比较特殊。要把这个节点在 原树 上的父亲 access
一下,再 splay
这个节点,这样子这个节点的信息就是这颗子树的信息。
我access在lct上的父亲调了一晚上(((
矩乘只是 \(2\times2\times2\) ,建议手动暴力展开,可以快非常多。
因为树剖代码是在之前的代码上改的,怕有些地方与描述不同,因此重写了一份 LCT 的代码
树剖版代码
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef double db;
#define pb(x) push_back(x)
#define mkp(x,y) make_pair(x,y)
//#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
//char buf[1<<21],*p1=buf,*p2=buf;
inline int read() {
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)) {if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch))x=x*10+(ch^48),ch=getchar();
return x*f;
}
int rdc(){
char ch=getchar();
while(ch!='Q'&&ch!='C')ch=getchar();
return ch=='Q';
}
const int N=200005;
const LL inf=1e14;
const int T=N<<2;
int n,dp[N];
LL a[N];
int head[N],num_edge;
int dfn[N],rev[N],tmr,fa[N],siz[N],son[N],top[N],ed[N];
struct edge{
int nxt,to;
}e[N<<1];
void addedge(int fr,int to){
++num_edge;
e[num_edge].nxt=head[fr];
e[num_edge].to=to;
head[fr]=num_edge;
}
struct Matrix {
LL a[2][2];
Matrix(){a[0][0]=a[0][1]=a[1][0]=a[1][1]=inf;}
LL*operator[](const int&k){return a[k];}
Matrix operator * (const Matrix&b){
Matrix res;
res.a[0][0] = min(a[0][0]+b.a[0][0],a[0][1]+b.a[1][0]);
res.a[0][1] = min(a[0][0]+b.a[0][1],a[0][1]+b.a[1][1]);
res.a[1][0] = min(a[1][0]+b.a[0][0],a[1][1]+b.a[1][0]);
res.a[1][1] = min(a[1][0]+b.a[0][1],a[1][1]+b.a[1][1]);
return res;
}
void print(){
printf("%lld %lld\n%lld %lld\n\n",a[0][0],a[0][1],a[1][0],a[1][1]);
}
}mat[N],val[T];
void dfs1(int u,int ft){
if(!e[head[u]].nxt)return dp[u]=a[u],siz[u]=1,void();
LL sum=0;siz[u]=1;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;if(v==ft)continue;
fa[v]=u,dfs1(v,u),sum+=dp[v],siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
dp[u]=min(sum,a[u]);
}
void dfs2(int u,int tp){
top[u]=tp,dfn[u]=++tmr,rev[tmr]=u,ed[tp]=tmr;
mat[u][0][0]=0,mat[u][0][1]=a[u],
mat[u][1][0]=0,mat[u][1][1]=0;
if(!son[u])return mat[u][0][0]=a[u],void();
dfs2(son[u],tp);
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa[u]||v==son[u])continue;
dfs2(v,v),mat[u][0][0]+=dp[v];
}
}
#define lc (p<<1)
#define rc (p<<1|1)
void pushup(int p){val[p]=val[lc]*val[rc];}
void build(int l,int r,int p=1){
if(l==r)return val[p]=mat[rev[l]],void();
int mid=(l+r)>>1;
build(l,mid,lc),build(mid+1,r,rc);
pushup(p);
}
Matrix query(int ql,int qr,int l=1,int r=n,int p=1){
if(ql<=l&&r<=qr)return val[p];
int mid=(l+r)>>1;
if(qr<=mid)return query(ql,qr,l,mid,lc);
if(mid<ql)return query(ql,qr,mid+1,r,rc);
return query(ql,qr,l,mid,lc)*query(ql,qr,mid+1,r,rc);
}
void change(int pos,int l=1,int r=n,int p=1){
if(l==r)return val[p]=mat[rev[pos]],void();
int mid=(l+r)>>1;
if(pos<=mid)change(pos,l,mid,lc);
else change(pos,mid+1,r,rc);
pushup(p);
}
void update(int x,int v){
mat[x][0][1]+=v,a[x]+=v;
if(siz[x]==1)mat[x][0][0]+=v;
while(x){
Matrix lst=query(dfn[top[x]],ed[top[x]]);
change(dfn[x]);
Matrix now=query(dfn[top[x]],ed[top[x]]);
x=fa[top[x]];
mat[x][0][0]+=now[0][0]-lst[0][0];
}
}
signed main(){
n=read();
for(int i=1;i<=n;++i)a[i]=read();
for(int i=1;i<n;++i){
int x=read(),y=read();
addedge(x,y),addedge(y,x);
}
dfs1(1,0),dfs2(1,1),build(1,n);
for(int m=read();m;--m){
int opt=rdc(),x=read();
if(opt){
Matrix t=query(dfn[x],ed[top[x]]);
printf("%lld\n",t[0][0]);
}
else update(x,read());
}
return 0;
}
LCT 版代码
#include <bits/stdc++.h>
using namespace std;
typedef double db;
typedef long long LL;
#define fi first
#define se second
#define pb push_back
#define mkp make_pair
#define sz(v) (int)(v).size()
#define rep(i, x, y) for(int i = x, i##end = y; i <= i##end; ++i)
#define per(i, x, y) for(int i = x, i##end = y; i >= i##end; --i)
template<class T> inline bool ckmax(T &x, T y) { return x < y ? x = y, 1 : 0; }
template<class T> inline bool ckmin(T &x, T y) { return x > y ? x = y, 1 : 0; }
inline int read() {
int x = 0, f = 1; char ch = getchar();
while(!isdigit(ch)) { if(ch == '-') f = 0; ch = getchar(); }
while(isdigit(ch)) x = x * 10 + ch - '0', ch = getchar();
return f ? x : -x;
}
inline int rdch() {
char ch = getchar();
while(ch != 'Q' && ch != 'C') ch = getchar();
return ch == 'Q';
}
const int N = 200005;
const LL inf = 1e14;
int n, m, lef[N], tfa[N];
LL a[N], dp[N];
vector<int> e[N];
int fa[N], ch[N][2];
struct Matrix {
LL a[2][2];
Matrix() { a[0][0] = a[0][1] = a[1][0] = a[1][1] = inf; }
inline LL* operator [](const int &k) { return a[k]; }
inline Matrix operator * (const Matrix &t) const {
Matrix res;
res.a[0][0] = min(a[0][0] + t.a[0][0], a[0][1] + t.a[1][0]);
res.a[0][1] = min(a[0][0] + t.a[0][1], a[0][1] + t.a[1][1]);
res.a[1][0] = min(a[1][0] + t.a[0][0], a[1][1] + t.a[1][0]);
res.a[1][1] = min(a[1][0] + t.a[0][1], a[1][1] + t.a[1][1]);
return res;
}
} val[N], sum[N];
inline bool nroot(int x) { return ch[fa[x]][0] == x || ch[fa[x]][1] == x; }
inline void pushup(int x) {
sum[x] = val[x];
if(ch[x][0]) sum[x] = sum[ch[x][0]] * sum[x];
if(ch[x][1]) sum[x] = sum[x] * sum[ch[x][1]];
}
inline void rotate(int x) {
int y = fa[x], z = fa[y], k = ch[y][1] == x, w = ch[x][!k];
if(nroot(y)) ch[z][ch[z][1] == y] = x;
ch[x][!k] = y, ch[y][k] = w;
fa[w] = y, fa[y] = x, fa[x] = z;
pushup(y);
}
inline void splay(int x) {
while(nroot(x)) {
int y = fa[x], z = fa[y];
if(nroot(y)) rotate((ch[z][1] == y) ^ (ch[y][1] == x) ? x : y);
rotate(x);
}
pushup(x);
}
inline void access(int x) {
for(int y = 0; x; x = fa[y = x]) {
splay(x);
if(y) val[x][0][0] -= min(sum[y][0][0], sum[y][0][1]);
if(ch[x][1]) val[x][0][0] += min(sum[ch[x][1]][0][0], sum[ch[x][1]][0][1]);
ch[x][1] = y, pushup(x);
}
}
void dfs(int u, int ft) {
lef[u] = 1;
for(int v : e[u]) if(v != ft)
tfa[v] = fa[v] = u, dfs(v, u), dp[u] += dp[v], lef[u] = 0;
if(lef[u]) dp[u] = a[u];
val[u][0][0] = lef[u] ? inf : dp[u], val[u][0][1] = a[u];
val[u][1][0] = val[u][1][1] = 0;
pushup(u);
ckmin(dp[u], a[u]);
}
signed main() {
n = read();
rep(i, 1, n) a[i] = read();
rep(i, 2, n) {
int x = read(), y = read();
e[x].pb(y), e[y].pb(x);
}
dfs(1, 0);
for(m = read(); m--; ) {
int op = rdch(), x = read();
if(op) {
if(tfa[x]) access(tfa[x]);
splay(x), printf("%lld\n", min(sum[x][0][0], sum[x][0][1]));
} else {
int y = read();
access(x), splay(x);
val[x][0][1] += y;
pushup(x);
}
}
}