C73【模板】动态DP+树剖+矩阵乘+线段树 P4719 动态树分治
视频链接:259【模板】动态DP+树剖+矩阵乘+线段树 P4719 动态树分治_哔哩哔哩_bilibili
#include <iostream> #include <cstring> #include <algorithm> #include <vector> using namespace std; const int N=100005,inf=0x3f3f3f3f; #define ls (u<<1) #define rs (u<<1|1) #define mid ((l+r)>>1) int n,m,a[N]; vector<int>G[N]; int fa[N],siz[N],son[N],f[N][2]; int dfn[N],id[N],top[N],bot[N],tot; //dfn:dfs序,id:节点编号,top:链头节点,bot:链尾序号 void dfs(int x){ //树剖fa,siz,son,f siz[x]=1; f[x][1]=a[x]; for(int y:G[x]){ if(y==fa[x]) continue; fa[y]=x; dfs(y); siz[x]+=siz[y]; if(siz[y]>siz[son[x]]) son[x]=y; f[x][0]+=max(f[y][0],f[y][1]); f[x][1]+=f[y][0]; } } void dfs(int x,int tp){ //树剖dfn,id,top,bot dfn[x]=++tot;id[tot]=x;top[x]=tp;bot[tp]=tot; if(son[x]) dfs(son[x],tp); for(auto y:G[x]){ if(y==fa[x]||y==son[x]) continue; dfs(y,y); } } struct matrix{ int g[2][2]; matrix operator*(matrix b){ //广义矩阵乘积 matrix t; memset(t.g,-0x3f,sizeof(t.g)); for(int i=0; i<=1; ++i) for(int j=0; j<=1; ++j) for(int k=0; k<=1; ++k) t.g[i][j]=max(t.g[i][j],g[i][k]+b.g[k][j]); return t; } }mt[N],tr[N<<2];//节点g矩阵,线段树g矩阵及乘积 void build(int u,int l,int r){ //建线段树 if(l==r){ int x=id[l],g0=0,g1=a[x]; //g0不选x,g1选x for(auto y:G[x]) if(y!=fa[x]&&y!=son[x]) g0+=max(f[y][0],f[y][1]), //g0选或不选y g1+=f[y][0]; //g1不选y tr[u]=mt[x]={g0,g0,g1,-inf}; return; } build(ls,l,mid); build(rs,mid+1,r); tr[u]=tr[ls]*tr[rs]; //g矩阵乘积 } void change(int u,int l,int r,int p){ //点修 if(l==r){tr[u]=mt[id[l]]; return;} if(p<=mid) change(ls,l,mid,p); else change(rs,mid+1,r,p); tr[u]=tr[ls]*tr[rs]; } matrix query(int u,int l,int r,int x,int y){ //区查 if(x==l&&r==y) return tr[u]; if(y<=mid) return query(ls,l,mid,x,y); if(x>mid) return query(rs,mid+1,r,x,y); return query(ls,l,mid,x,mid)*query(rs,mid+1,r,mid+1,y); } void update(int u,int v){ //修改点权 mt[u].g[1][0]+=v-a[u]; a[u]=v; while(u){ matrix a=query(1,1,n,dfn[top[u]],bot[top[u]]); change(1,1,n,dfn[u]); matrix b=query(1,1,n,dfn[top[u]],bot[top[u]]); u=fa[top[u]]; //跳到链头的父节点 mt[u].g[0][0]+=max(b.g[0][0],b.g[1][0]) -max(a.g[0][0],a.g[1][0]); mt[u].g[0][1]=mt[u].g[0][0]; mt[u].g[1][0]+=b.g[0][0]-a.g[0][0]; } } int main(){ scanf("%d%d",&n,&m); int x,y; for(int i=1;i<=n;i++)scanf("%d",&a[i]); for(int i=1;i<n;i++){ scanf("%d%d",&x,&y); G[x].push_back(y); G[y].push_back(x); } dfs(1);dfs(1,1); //树链剖分 build(1,1,n); //建线段树 while(m--){ scanf("%d%d",&x,&y); update(x,y); //修改点权 matrix ans=query(1,1,n,dfn[1],bot[1]); printf("%d\n",max(ans.g[0][0],ans.g[1][0])); } }
#include <iostream> #include <cstring> #include <algorithm> #include <vector> using namespace std; const int N=100005,inf=0x3f3f3f3f; #define ls (u<<1) #define rs (u<<1|1) #define mid ((l+r)>>1) int n,m,a[N]; vector<int>G[N]; int fa[N],siz[N],son[N],f[N][2]; int dfn[N],id[N],top[N],bot[N],tot; //dfn:dfs序,id:节点编号,top:链头节点,bot:链尾序号 struct matrix{ int g[2][2]; matrix operator*(matrix b){ //广义矩阵乘积 matrix t; memset(t.g,-0x3f,sizeof(t.g)); for(int i=0; i<=1; ++i) for(int j=0; j<=1; ++j) for(int k=0; k<=1; ++k) t.g[i][j]=max(t.g[i][j],g[i][k]+b.g[k][j]); return t; } }mt[N],tr[N<<2];//节点g矩阵,线段树g矩阵及乘积 void dfs(int x){ //树剖fa,siz,son,f siz[x]=1; f[x][1]=a[x]; for(int y:G[x]){ if(y==fa[x]) continue; fa[y]=x; dfs(y); siz[x]+=siz[y]; if(siz[y]>siz[son[x]]) son[x]=y; f[x][0]+=max(f[y][0],f[y][1]); f[x][1]+=f[y][0]; } } void dfs(int x,int tp){ //树剖dfn,id,top,bot,mt dfn[x]=++tot;id[tot]=x;top[x]=tp;bot[tp]=tot; if(son[x]) dfs(son[x],tp); mt[x].g[1][0]=a[x]; mt[x].g[1][1]=-inf; for(auto y:G[x]){ if(y==fa[x]||y==son[x]) continue; dfs(y,y); mt[x].g[0][0]+=max(f[y][0],f[y][1]); mt[x].g[0][1]=mt[x].g[0][0]; mt[x].g[1][0]+=f[y][0]; } } void build(int u,int l,int r){ //建线段树 if(l==r){tr[u]=mt[id[l]]; return;} build(ls,l,mid); build(rs,mid+1,r); tr[u]=tr[ls]*tr[rs]; //g矩阵乘积 } void change(int u,int l,int r,int p){ //点修 if(l==r){tr[u]=mt[id[l]]; return;} if(p<=mid) change(ls,l,mid,p); else change(rs,mid+1,r,p); tr[u]=tr[ls]*tr[rs]; } matrix query(int u,int l,int r,int x,int y){ //区查 if(x==l&&r==y) return tr[u]; if(y<=mid) return query(ls,l,mid,x,y); if(x>mid) return query(rs,mid+1,r,x,y); return query(ls,l,mid,x,mid)*query(rs,mid+1,r,mid+1,y); } void update(int u,int v){ //修改点权 mt[u].g[1][0]+=v-a[u]; a[u]=v; while(u){ matrix a=query(1,1,n,dfn[top[u]],bot[top[u]]); change(1,1,n,dfn[u]); matrix b=query(1,1,n,dfn[top[u]],bot[top[u]]); u=fa[top[u]]; //跳到链头的父节点 mt[u].g[0][0]+=max(b.g[0][0],b.g[1][0]) -max(a.g[0][0],a.g[1][0]); mt[u].g[0][1]=mt[u].g[0][0]; mt[u].g[1][0]+=b.g[0][0]-a.g[0][0]; } } int main(){ scanf("%d%d",&n,&m); int x,y; for(int i=1;i<=n;i++)scanf("%d",&a[i]); for(int i=1;i<n;i++){ scanf("%d%d",&x,&y); G[x].push_back(y); G[y].push_back(x); } dfs(1);dfs(1,1); //树链剖分 build(1,1,n); //建线段树 while(m--){ scanf("%d%d",&x,&y); update(x,y); //修改点权 matrix ans=query(1,1,n,dfn[1],bot[1]); printf("%d\n",max(ans.g[0][0],ans.g[1][0])); } }
// 常数优化的方法 #include <iostream> #include <cstring> #include <algorithm> using namespace std; int read(){ int s=0,w=1;char ch; while((ch=getchar())>'9'||ch<'0') if(ch=='-') w=-1; while(ch>='0'&&ch<='9') s=s*10+(ch^48),ch=getchar(); return s*w; } const int N=100005,inf=0x3f3f3f3f; #define mid ((l+r)>>1) int h[N],to[N<<1],ne[N<<1],idx; void add(int a,int b){to[++idx]=b,ne[idx]=h[a],h[a]=idx;} int n,m,a[N]; int root[N],ls[N<<2],rs[N<<2],nod; int fa[N],siz[N],son[N],f[N][2]; int dfn[N],id[N],top[N],bot[N],tot; //dfn:dfs序,id:节点编号,top:链头节点,bot:链尾序号 struct matrix{ int g[2][2]; matrix operator*(matrix b){ //广义矩阵乘积 matrix t; t.g[0][0]=max(g[0][0]+b.g[0][0],g[0][1]+b.g[1][0]); t.g[1][0]=max(g[1][0]+b.g[0][0],g[1][1]+b.g[1][0]); t.g[0][1]=max(g[0][0]+b.g[0][1],g[0][1]+b.g[1][1]); t.g[1][1]=max(g[1][0]+b.g[0][1],g[1][1]+b.g[1][1]); return t; } }mt[N],tr[N<<2]; //节点g矩阵, 线段树g矩阵及乘积 void dfs(int x){ //树剖fa,siz,son,f siz[x]=1; f[x][1]=a[x]; for(int i=h[x];i;i=ne[i]){ int y=to[i]; if(y==fa[x]) continue; fa[y]=x; dfs(y); siz[x]+=siz[y]; if(siz[y]>siz[son[x]]) son[x]=y; f[x][0]+=max(f[y][1],f[y][0]); f[x][1]+=f[y][0]; } } void dfs(int x,int tp){ //树剖dfn,id,top,bot,mt dfn[x]=++tot;id[tot]=x;top[x]=tp;bot[tp]=tot; if(son[x]) dfs(son[x],tp); mt[x].g[1][0]=a[x]; mt[x].g[1][1]=-inf; for(int i=h[x];i;i=ne[i]){ int y=to[i]; if(y==fa[x]||y==son[x]) continue; dfs(y,y); mt[x].g[0][0]+=max(f[y][0],f[y][1]); mt[x].g[0][1]=mt[x].g[0][0]; mt[x].g[1][0]+=f[y][0]; } } void build(int &u,int l,int r){ //建线段树 u=++nod; if(l==r){tr[u]=mt[id[l]]; return;} build(ls[u],l,mid); build(rs[u],mid+1,r); tr[u]=tr[ls[u]]*tr[rs[u]]; } void change(int u,int l,int r,int p){ //点修 if(l==r){tr[u]=mt[id[l]]; return;} if(p<=mid) change(ls[u],l,mid,p); else change(rs[u],mid+1,r,p); tr[u]=tr[ls[u]]*tr[rs[u]]; } matrix query(int u,int l,int r,int x,int y){ //区查 if(x==l&&r==y) return tr[u]; if(y<=mid) return query(ls[u],l,mid,x,y); if(x>mid) return query(rs[u],mid+1,r,x,y); return query(ls[u],l,mid,x,mid)*query(rs[u],mid+1,r,mid+1,y); } void update(int u,int v){ //修改点权 mt[u].g[1][0]+=v-a[u]; a[u]=v; while(u){ matrix a=tr[root[top[u]]]; change(root[top[u]],dfn[top[u]],bot[top[u]],dfn[u]); matrix b=tr[root[top[u]]]; u=fa[top[u]]; //跳到链头的父节点 mt[u].g[0][0]+=max(b.g[0][0],b.g[1][0]) -max(a.g[0][0],a.g[1][0]); mt[u].g[0][1]=mt[u].g[0][0]; mt[u].g[1][0]+=b.g[0][0]-a.g[0][0]; } } int main(){ n=read(),m=read(); int x,y; for(int i=1;i<=n;i++) a[i]=read(); for(int i=1;i<n;i++){ x=read(),y=read(); add(x,y); add(y,x); } dfs(1);dfs(1,1); //树链剖分 for(int i=1;i<=n;i++) //建线段树 if(top[i]==i) build(root[i],dfn[i],bot[i]); for(int i=1;i<=m;i++){ x=read(), y=read(); update(x,y); //修改点权 matrix ans=tr[root[1]]; printf("%d\n",max(ans.g[0][0],ans.g[1][0])); } }
Luogu P4751 【模板】"动态DP"&动态树分治(加强版)
#include <iostream> #include <cstring> #include <algorithm> using namespace std; int read(){ int s=0,w=1;char ch; while((ch=getchar())>'9'||ch<'0') if(ch=='-') w=-1; while(ch>='0'&&ch<='9') s=s*10+(ch^48),ch=getchar(); return s*w; } const int N=1000005,inf=0x3f3f3f3f; #define mid ((l+r)>>1) int h[N],to[N<<1],ne[N<<1],idx; void add(int a,int b){to[++idx]=b,ne[idx]=h[a],h[a]=idx;} int n,m,last,a[N]; int root[N],ls[N<<2],rs[N<<2],nod; int fa[N],siz[N],son[N],dfn[N],id[N],top[N],bot[N],f[N][2],tot; //son:重儿子,dfn:dfs序,id:节点编号,top:链头节点,bot:链尾序号 struct matrix{ int g[2][2]; matrix operator*(matrix b){ //广义矩阵乘积 matrix t; t.g[0][0]=max(g[0][0]+b.g[0][0],g[0][1]+b.g[1][0]); t.g[1][0]=max(g[1][0]+b.g[0][0],g[1][1]+b.g[1][0]); t.g[0][1]=max(g[0][0]+b.g[0][1],g[0][1]+b.g[1][1]); t.g[1][1]=max(g[1][0]+b.g[0][1],g[1][1]+b.g[1][1]); return t; } }mt[N],tr[N<<2]; //节点g矩阵, 线段树g矩阵及乘积 void dfs(int x){ //树剖fa,siz,son,f siz[x]=1; f[x][1]=a[x]; for(int i=h[x];i;i=ne[i]){ int y=to[i]; if(y==fa[x]) continue; fa[y]=x; dfs(y); siz[x]+=siz[y]; if(siz[y]>siz[son[x]]) son[x]=y; f[x][0]+=max(f[y][1],f[y][0]); f[x][1]+=f[y][0]; } } void dfs(int x,int tp){ //树剖dfn,id,top,bot,mt dfn[x]=++tot;id[tot]=x;top[x]=tp;bot[tp]=tot; if(son[x]) dfs(son[x],tp); mt[x].g[1][0]=a[x]; mt[x].g[1][1]=-inf; for(int i=h[x];i;i=ne[i]){ int y=to[i]; if(y==fa[x]||y==son[x]) continue; dfs(y,y); mt[x].g[0][0]+=max(f[y][0],f[y][1]); mt[x].g[0][1]=mt[x].g[0][0]; mt[x].g[1][0]+=f[y][0]; } } void build(int &u,int l,int r){ //建线段树 u=++nod; if(l==r){tr[u]=mt[id[l]]; return;} build(ls[u],l,mid); build(rs[u],mid+1,r); tr[u]=tr[ls[u]]*tr[rs[u]]; } void change(int u,int l,int r,int p){ //点修 if(l==r){tr[u]=mt[id[l]]; return;} if(p<=mid) change(ls[u],l,mid,p); else change(rs[u],mid+1,r,p); tr[u]=tr[ls[u]]*tr[rs[u]]; } matrix query(int u,int l,int r,int x,int y){ //区查 if(x==l&&r==y) return tr[u]; if(y<=mid) return query(ls[u],l,mid,x,y); if(x>mid) return query(rs[u],mid+1,r,x,y); return query(ls[u],l,mid,x,mid)*query(rs[u],mid+1,r,mid+1,y); } void update(int u,int v){ //修改点权 mt[u].g[1][0]+=v-a[u]; a[u]=v; while(u){ matrix a=tr[root[top[u]]]; change(root[top[u]],dfn[top[u]],bot[top[u]],dfn[u]); matrix b=tr[root[top[u]]]; u=fa[top[u]]; //上跳 mt[u].g[0][0]+=max(b.g[0][0],b.g[1][0])-max(a.g[0][0],a.g[1][0]); mt[u].g[0][1]=mt[u].g[0][0]; mt[u].g[1][0]+=b.g[0][0]-a.g[0][0]; } } int main(){ n=read(),m=read(); int x,y; for(int i=1;i<=n;i++) a[i]=read(); for(int i=1;i<n;i++){ x=read(),y=read(); add(x,y); add(y,x); } dfs(1);dfs(1,1); for(int i=1;i<=n;i++) if(top[i]==i) build(root[i],dfn[i],bot[i]); for(int i=1;i<=m;i++){ x=read()^last, y=read(); update(x,y); matrix ans=tr[root[1]]; printf("%d\n",last=max(ans.g[0][0],ans.g[1][0])); } }