动态DP(动态树分治)
Link
首先我们有一个静态的dp。
设\(f_{u,0/1}\)表示只考虑\(u\)的子树,\(u\)不选/选的答案。
那么很显然有:
\[\begin{aligned}
f_{u,0}&=\sum\limits_{v\in son_u}\max(f_{v,0},f_{v,1})\\
f_{u,1}&=w_u+\sum\limits_{v\in son_u}f_{v,0}
\end{aligned}
\]
考虑利用重链剖分来进行这个过程,设\(h_u\)表示\(u\)的重儿子,\(g_{u,0/1}\)表示\(f_{u,0/1}\)在不考虑\(h_u\)子树情况下的答案。
那么有:
\[\begin{aligned}
g_{u,0}&=\sum\limits_{v\in son_u\wedge v\ne h_u}\max(f_{v,0},f_{v,1})\\
g_{u,1}&=w_u+\sum\limits_{v\in son_u\wedge v\ne h_u}f_{v,0}\\
f_{u,0}&=\max(f_{h_u,0},f_{h_u,1})+g_{u,0}\\
f_{u,1}&=f_{h_u,0}+g_{u,1}
\end{aligned}
\]
对于一条重链,实际上我们只关心\(f_{top}\)。
假如我们已经求出了链上所有点的\(g\),那么我们可以做一个序列dp得到\(f_{top}\)。
实际上我们可以把这个序列dp的转移写成矩阵乘法的形式。
定义\(C=AB\)为满足\(C_{i,j}=\max\limits_k(A_{i,k}+B_{k,j})\)的矩阵,那么有:
\[\begin{pmatrix}f_{h_u,0}&f_{h_u,1}\end{pmatrix}\begin{pmatrix}g_{u,0}&g_{u,1}\\g_{u,0}&-\infty\end{pmatrix}=\begin{pmatrix}f_{u,0}&f_{u,1}\end{pmatrix}
\]
注意到新定义的矩阵乘法仍然具有结合律,因此我们可以用线段树维护每条重链上的矩阵的区间积。为了方便我们在线段树外同时记录每个点的转移矩阵。
这样做的时间复杂度为\(O(n\log n+q\log^2n)\)。
#include<cctype>
#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
const int N=100007,inf=1e9;
char ibuf[1<<23|1],*iS=ibuf;
int n,m,val[N],fa[N],size[N],son[N],top[N],ch[N],dfn[N],id[N],f[N][2];
std::vector<int>e[N];
struct matrix{int a[2][2];int*operator[](int x){return a[x];}}t[4*N],a[N];
matrix operator*(matrix a,matrix b)
{
matrix c;
c[0][0]=std::max(a[0][0]+b[0][0],a[0][1]+b[1][0]),c[0][1]=std::max(a[0][0]+b[0][1],a[0][1]+b[1][1]);
c[1][0]=std::max(a[1][0]+b[0][0],a[1][1]+b[1][0]),c[1][1]=std::max(a[1][0]+b[0][1],a[1][1]+b[1][1]);
return c;
}
int read(){int x=0,f=1;while(isspace(*iS))++iS;if(*iS=='-')++iS,f=-1;while(isdigit(*iS))(x*=10)+=*iS++&15;return f*x;}
void dfs1(int u)
{
size[u]=1;
for(int v:e[u]) if(v^fa[u]) if(fa[v]=u,dfs1(v),size[u]+=size[v],size[v]>size[son[u]]) son[u]=v;
}
void dfs2(int u,int tp)
{
static int tim;id[dfn[u]=++tim]=ch[u]=u,top[u]=tp;
if(son[u]) dfs2(son[u],tp),ch[u]=ch[son[u]];
for(int v:e[u]) if(v^fa[u]&&v^son[u]) dfs2(v,v);
}
void dfs3(int u)
{
f[u][1]=val[u];
for(int v:e[u]) if(v^fa[u]) dfs3(v),f[u][0]+=std::max(f[v][0],f[v][1]),f[u][1]+=f[v][0];
}
matrix get(int u)
{
int g0=0,g1=val[u];
for(int v:e[u]) if(v^fa[u]&&v^son[u]) g0+=std::max(f[v][0],f[v][1]),g1+=f[v][0];
return {g0,g0,g1,-inf};
}
#define ls p<<1
#define rs p<<1|1
#define mid ((l+r)/2)
void pushup(int p){t[p]=t[ls]*t[rs];}
void build(int p,int l,int r)
{
if(l==r) return a[l]=t[p]=get(id[l]),void();
build(ls,l,mid),build(rs,mid+1,r),pushup(p);
}
void update(int p,int l,int r,int x)
{
if(l==r) return t[p]=a[l],void();
x<=mid? update(ls,l,mid,x):update(rs,mid+1,r,x),pushup(p);
}
matrix query(int p,int l,int r,int L,int R)
{
if(L<=l&&r<=R) return t[p];
if(R<=mid) return query(ls,l,mid,L,R);
if(L>mid) return query(rs,mid+1,r,L,R);
return query(ls,l,mid,L,R)*query(rs,mid+1,r,L,R);
}
#undef ls
#undef rs
#undef mid
void modify(int u,int w)
{
a[dfn[u]][1][0]+=w-val[u],val[u]=w;
while(u)
{
matrix p=query(1,1,n,dfn[top[u]],dfn[ch[u]]);
update(1,1,n,dfn[u]);
matrix q=query(1,1,n,dfn[top[u]],dfn[ch[u]]);
if(!(u=fa[top[u]]))break;
int x=dfn[u],g0=p[0][0],g1=p[1][0],f0=q[0][0],f1=q[1][0];
a[x][0][0]=a[x][0][1]=a[x][0][0]+std::max(f0,f1)-std::max(g0,g1),a[x][1][0]=a[x][1][0]+f0-g0;
}
}
void work()
{
int u=read(),w=read();modify(u,w);
matrix ans=query(1,1,n,dfn[1],dfn[ch[1]]);
printf("%d\n",std::max(ans[0][0],ans[1][0]));
}
int main()
{
fread(ibuf,1,1<<23,stdin);
n=read(),m=read();
for(int i=1;i<=n;++i) val[i]=read();
for(int i=1,u,v;i<n;++i) u=read(),v=read(),e[u].push_back(v),e[v].push_back(u);
dfs1(1),dfs2(1,1),dfs3(1);
build(1,1,n);
for(int i=1;i<=m;++i) work();
}
还有一个叫做全局平衡二叉树的东西。
类似于LCT,大致思想还是用Splay维护每一条重链。
注意到树是静态的,因此并不需要支持rotate等改变数的形态的操作,因此常数会小很多。
对于每条重链,为了建出较为平衡的bst,我们按轻儿子\(size\)之和的加权重心递归建树。
#include<cctype>
#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
const int N=1000007,inf=1e9;
char ibuf[1<<27|1],*iS=ibuf;
int n,q,val[N],son[N],sz[N];std::vector<int>e[N];
int read(){int x=0,f=1;while(isspace(*iS))++iS;if(*iS=='-')++iS,f=-1;while(isdigit(*iS))(x*=10)+=*iS++&15;return f*x;}
struct matrix
{
int a[2][2];
matrix(){a[0][0]=a[0][1]=a[1][0]=a[1][1]=-inf;}
int*operator[](int x){return a[x];}
int cal(){return std::max(std::max(a[0][0],a[0][1]),std::max(a[1][0],a[1][1]));}
};
matrix operator*(matrix a,matrix b)
{
matrix c;
c[0][0]=std::max(a[0][0]+b[0][0],a[0][1]+b[1][0]),c[0][1]=std::max(a[0][0]+b[0][1],a[0][1]+b[1][1]);
c[1][0]=std::max(a[1][0]+b[0][0],a[1][1]+b[1][0]),c[1][1]=std::max(a[1][0]+b[0][1],a[1][1]+b[1][1]);
return c;
}
void dfs1(int u,int fa)
{
sz[u]=1;
for(int v:e[u]) if(v^fa) if(dfs1(v,u),sz[u]+=sz[v],sz[v]>sz[son[u]]) son[u]=v;
}
struct BST
{
int root,top,ch[N][2],fa[N],stk[N],vis[N],size[N];matrix f[N],sum[N];
void init(){f[0][0][0]=f[0][1][1]=sum[0][0][0]=sum[0][1][1]=0;for(int i=1;i<=n;++i)f[i][0][1]=val[i],f[i][0][0]=f[i][1][0]=0;}
int nroot(int p){return ch[fa[p]][0]==p||ch[fa[p]][1]==p;}
void pushup(int p){sum[p]=sum[ch[p][0]]*f[p]*sum[ch[p][1]];}
void merge(int u,int v){f[u][1][0]+=sum[v].cal(),f[u][0][0]=f[u][1][0],f[u][0][1]+=std::max(sum[v][0][0],sum[v][1][0]),fa[v]=u;}
int build(int l,int r)
{
if(l>r) return 0;
int tot=0;for(int i=l;i<=r;++i)tot+=size[stk[i]];
for(int i=l,now=size[stk[i]],ls,rs;i<=r;++i,now+=size[stk[i]])
if(2*now>=tot)
return ls=build(l,i-1),rs=build(i+1,r),ch[stk[i]][0]=rs,ch[stk[i]][1]=ls,fa[ls]=fa[rs]=stk[i],pushup(stk[i]),stk[i];
}
int build(int p)
{
for(int u=p;u;u=son[u]) vis[u]=1;
for(int u=p;u;u=son[u]) for(int v:e[u]) if(!vis[v]) merge(u,build(v));
top=0;for(int u=p;u;u=son[u])stk[++top]=u,size[u]=sz[u]-sz[son[u]];
return build(1,top);
}
void update(int u,int w)
{
f[u][0][1]+=w-val[u],val[u]=w;
for(int v=u;v;v=fa[v])
if(!nroot(v)&&fa[v])
{
f[fa[v]][0][0]-=sum[v].cal(),f[fa[v]][1][0]=f[fa[v]][0][0],f[fa[v]][0][1]-=std::max(sum[v][0][0],sum[v][1][0]);
pushup(v);
f[fa[v]][0][0]+=sum[v].cal(),f[fa[v]][1][0]=f[fa[v]][0][0],f[fa[v]][0][1]+=std::max(sum[v][0][0],sum[v][1][0]);
}
else pushup(v);
}
}bst;
int main()
{
fread(ibuf,1,1<<27,stdin);
n=read(),q=read();
for(int i=1;i<=n;++i) val[i]=read();
for(int i=1,u,v;i<n;++i) u=read(),v=read(),e[u].push_back(v),e[v].push_back(u);
dfs1(1,0),bst.init(),bst.root=bst.build(1);
for(int i=1,u,w;i<=q;++i) u=read(),w=read(),bst.update(u,w),printf("%d\n",bst.sum[bst.root].cal());
}