"动态 DP"&动态树分治 整理
刚刚学完,整理一下。
一般是对于简单的树形dp,加上丧心病狂的修改操作,并且支持在线操作的解决方法。
看洛谷的模板题
给定一棵 \(n\) 个点的树,点带点权,有 \(m\) 次操作,每次操作给定 \(x,y\),表示修改点 \(x\) 的权值为 \(y\)。你需要在每次操作之后求出这棵树的最大权独立集的权值大小。
显然,如果没有修改操作的话是可以直接线性求出答案的。我们设 f[u][1] 为选取点 u 的答案,f[u][0] 为不选取点 u 的答案,可以得到转移方程:\(f[u][0] = \sum \max(f[v][0],f[v][1]),f[u][1] = a[u] + \sum f[v][0]\) 。
很明显对于所有的修改我们都重新求一遍答案的话复杂度会爆炸。
考虑每一次修改那些位置的 dp 数组发生了变化。明显只有修改位置到根这条链上会有变化。那么假如数据随机,可以 \(mlogn\) 通过此题。可数据不可能随机对吧,当树退化成一条链的时候,这个做法一定会被卡成 \(nm\) 。那么怎么办呢?
考虑进行轻重链剖分。
我们设 g 为所有轻儿子对父亲的贡献,那么就有了 \(f[u][0] = g[u][0] + \max(f[son[u]][0],f[son[u]][1]),f[u][1] = a[u] + g[u][0] + f[son[u]][0]\) 。发现 a[u] ,g[u][0] 都至于 u 有关,那就把 a[u] 放进 g[u][0] 里。
然后得到 \(f[u][0] = g[u][0] + \max(f[son[u]][0],f[son[u]][1]),f[u][1] = g[u][0] + f[son[u]][0]\) 。
非常显然,对于一条到根的链,最多只有 log 个轻儿子,也就是只有 log 个 g 需要修改,似乎可以加速了,yeah!
但是还有一个问题,那就是对于一条重链,我又如何快速得到答案呢?
我们把前面说的转移方程口胡到矩阵上,定义一种矩阵乘法 '' ,有 ab=c 时,\(c[i][j] = \max(a[i][k]+b[k][j])\)
可以得到转移矩阵 \(\begin{vmatrix}g[i][0]&f[i][0]\\f[i][1]&-inf\end{vmatrix} \times \begin{vmatrix}f[j][0]\\f[j][1]\end{vmatrix} = \begin{vmatrix}f[i][0]\\f[i][1]\end{vmatrix}\) ,在这里 j = son[i]。
口胡后发现它满足结合律,因此可以使用线段树来进行维护。
对于修改,我们往上跳,不停的撤销原先的贡献,再将新的贡献加上去。对于询问,我们在线段树上查询根这条重链的开头和结尾,求得答案。
#include <cstdio>
#include <cstring>
#include <algorithm>
#define mid (l+r>>1)
using namespace std;
int read()
{
int a = 0,x = 1;char ch = getchar();
while(ch > '9' || ch < '0') {if(ch == '-') x = -1;ch = getchar();}
while(ch >= '0' && ch <= '9') {a = a*10 + ch-'0';ch = getchar();}
return a*x;
}
const int N=1e6+7,inf = 1e9+7;
int n,m;
int head[N],go[N],nxt[N],cnt,a[N];
void add(int u,int v)
{
go[++cnt] = v;
nxt[cnt] = head[u];
head[u] = cnt;
}
#define end End
int dfn[N],pos[N],end[N],str[N],siz[N],son[N],g[N][2],fa[N],f[N][2],dep[N];
void dfs1(int u)
{
siz[u] = 1;
for(int e = head[u];e;e = nxt[e]) {
int v = go[e];if(v == fa[u]) continue;
fa[v] = u;dfs1(v);
siz[u] += siz[v];
if(siz[v] > siz[son[u]]) son[u] = v;
}
}
struct node{
int a[2][2];
node (int x,int y) {a[0][0] = a[0][1] = x,a[1][0] = y,a[1][1] = -inf;}
node () {}
friend node operator * (node a,node b)
{
node c = (node){-inf,-inf};
for(int i = 0;i < 2;i ++)
for(int j = 0;j < 2;j ++)
for(int k = 0;k < 2;k ++)
c.a[i][j] = max(c.a[i][j],a.a[i][k] + b.a[k][j]);
return c;
}
}val[N],tre[N];
void dfs2(int u,int h)
{
str[u] = h,end[h] = u;dep[u] = dep[fa[u]] + 1;
dfn[u] = ++cnt,pos[cnt] = u;
if(son[u]) dfs2(son[u],h);
f[u][1] = f[son[u]][0] + a[u],f[u][0] = max(f[son[u]][0],f[son[u]][1]);
for(int e = head[u];e;e = nxt[e]) {
int v = go[e];if(v == fa[u] || v == son[u]) continue;
dfs2(v,v);f[u][1] += f[v][0],f[u][0] += max(f[v][0],f[v][1]);
}
g[u][0] = f[u][0] - max(f[son[u]][0],f[son[u]][1]);
g[u][1] = f[u][1] - f[son[u]][0];
val[u] = (node){g[u][0],g[u][1]};
}
void build(int root,int l,int r)
{
if(l == r) {tre[root] = val[pos[l]];return ;}
build(root<<1,l,mid);build(root<<1|1,mid+1,r);
tre[root] = tre[root<<1] * tre[root<<1|1];
}
void update(int root,int l,int r,int p)
{
if(l == r && l == p) {tre[root] = val[pos[p]];return ;}
if(p <= mid) update(root<<1,l,mid,p);
else update(root<<1|1,mid+1,r,p);
tre[root] = tre[root<<1] * tre[root<<1|1];
}
node query(int root,int l,int r,int ql,int qr)
{
if(l >= ql && r <= qr) return tre[root];
if(qr <= mid) return query(root<<1,l,mid,ql,qr);
else if(ql > mid) return query(root<<1|1,mid+1,r,ql,qr);
else return query(root<<1,l,mid,ql,qr) * query(root<<1|1,mid+1,r,ql,qr);
}
void solve(int p,int x)
{
val[p].a[1][0] += x-a[p];
a[p] = x;node tmp1,tmp2;
while(p) {
tmp1 = query(1,1,n,dfn[str[p]],dfn[end[str[p]]]);
update(1,1,n,dfn[p]);
tmp2 = query(1,1,n,dfn[str[p]],dfn[end[str[p]]]);
p = fa[str[p]];
if(!p) break;
val[p].a[0][0] += max(tmp2.a[1][0],tmp2.a[0][0]) - max(tmp1.a[1][0],tmp1.a[0][0]);
val[p].a[0][1] = val[p].a[0][0];
val[p].a[1][0] += tmp2.a[0][0] - tmp1.a[0][0];
}
}
int main()
{
// freopen("random.in","r",stdin);
// freopen("sol.out","w",stdout);
n = read(),m = read();
for(int i = 1;i <= n;i ++) a[i] = read();
for(int i = 1;i < n;i ++) {
int u = read(),v = read();
add(u,v);add(v,u);
}
cnt = 0;dfs1(1),dfs2(1,1);
build(1,1,n);
for(int i = 1;i <= m;i ++) {
int p = read(),x = read();
solve(p,x);node tmp = query(1,1,n,dfn[1],dfn[end[1]]);
printf("%d\n",max(tmp.a[0][0],tmp.a[1][0]));
}
}