动态 dp 学习笔记
一、矩阵乘法
普通矩阵乘法
相信大家对矩阵乘法都不陌生,普通的矩阵乘法定义如下:
对于 \(n\times m\) 的矩阵 \(A\) 和 \(m\times q\) 的矩阵 \(B\) ,定义 \(C=A\cdot B\) ,其中:
\[c_{i,j}=\sum_{k=1}^ma_{i,k}\cdot b_{k,j}\\ \]矩阵 \(C\) 的大小为 \(n\times k\) ,单次矩阵乘法的时间复杂度为 \(O(nmk)\) 。
可以简记为:\(C\) 中第 \(i\) 行第 \(j\) 列的元素,等于 \(A\) 的第 \(i\) 行与 \(B\) 的第 \(j\) 列对应相乘再相加。
mat operator*(const mat &a,const mat &b)///传参O(1),传矩阵O(n^2)
{
static mat c;
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
{
c.v[i][j]=0;
for(int k=1;k<=n;k++) c.v[i][j]=(c.v[i][j]+1ll*a.v[i][k]*b.v[k][j])%mod;
}
return c;
}
广义矩阵乘法
不过在动态 \(\texttt{DP}\) 中,更常见的是 \((\min,+),(\max,+)\) 广义矩阵乘法:
常用卡常方法
设矩阵阶数为 \(w\) ,由于**矩阵乘法时间复杂度 \(O(w^3)\) **,而题目中 \(w\) 一般为常数,所以配上数据结构后很容易成为卡常重灾区。
矩阵乘法有两种常见卡常方法:
-
循环展开。
原理是手动展开从而避免使用
for
循环,仅适用于 \(n\) 很小的情形(一般\(n=2\))。struct mat { int a,b,c,d; }; mat operator*(const mat &x,const mat &y) { return {min(x.a+y.a,x.b+y.b),min(x.a+y.b,x.b+y.d),min(x.c+y.a,x.d+y.c),min(x.c+y.b,x.d+y.d)}; }
-
减少取模。
如果模数 \(p\approx 10^9\) ,那么
long long
可以承受 \(9\cdot p^2\) 的数据量。mat operator*(const mat &a,const mat &b) { static mat c; for(int i=1;i<=n;i++) for(int j=1;j<=n;j++) { ll res=0; for(int k=1;k<=n;k++) res+=1ll*a.v[i][k]*b.v[k][j]; c.v[i][j]=res%mod; } return c; }
二、动态 \(\texttt{DP}\) 概述
动态 \(\texttt{DP}\) 用于解决带修树形 \(\texttt{DP}\) 问题。
话不多说,先上模板题。
约定 \(\sum\limits_{v\in son(u)}\) 表示对 \(u\) 的所有子节点 \(v\) 求和, \(\sum\limits_{v\neq wson_u}\) 表示对 \(u\) 的所有轻儿子 \(v\) 求和。
例1、\(\texttt{P4719 【模板】"动态 DP"\&动态树分治}\)
题目描述
给定一棵 \(n\) 个节点的树,点有点权 \(w_i\) 。
\(m\) 次单点修改点权的操作,每次操作后询问最大权独立集。
数据范围
- \(1\le n,m\le 10^5,0\le |w_i|\le 10^2\) 。
时间限制 \(\texttt{1s}\) ,空间限制 \(\texttt{250MB}\) 。
分析
先考虑链怎么做。
\(f_{i,0/1}\) 表示仅考虑前 \(i\) 个数,不选/选第 \(i\) 个节点的最大收益。
写成 \((\max,+)\) 广义矩阵乘法:
修改点权等价于修改单个矩阵,线段树维护单点修改区间乘积即可。
再考虑树的情况。
\(f_{u,0/1}\) 表示仅考虑 \(u\) 子树,不选\(/\)选 \(u\) 的最大收益。
树链剖分转化为链上的问题,发现重儿子比较特殊,我们把轻儿子放在一起考虑。
\(g_{u,0/1}\) 表示仅考虑 \(u\) 自身及其轻子树,不选\(/\)选 \(u\) 的最大收益。
看起来好像变复杂了,但我们可以反过来用 \(g\) 化简 \(f\) 。
终于转化成熟悉的 \((\max,+)\) 广义矩阵乘法!
然后考虑怎么带修。
有一个显然的性质:只有 \(u\) 到根的路径上的点的 \(f\) 值会改变。
求 \(f\) 是简单的,线段树查询重链( \(dfn\) 区间)的矩阵乘积即可。
然而我们更关心哪些节点的 \(g\) 值会改变,因为 \(g\) 的变化会引起矩阵的变化。
结论是,只有 \(u\) 自身及所有 fa[top[u]]
的 \(g\) 值会改变!
从 top[u]
跳到 fa[top[u]]
的过程中,我们要先减掉原本对 fa[top[u]]
的贡献,更新信息后再加入新的对 fa[top[u]]
的贡献。
时间复杂度 \(\mathcal O(w^3n\log^2n)\) ,本题 \(w=2\) 。
至此本题的大致思路就讲完了,还有一些代码实现上的细节问题。
-
初始的 \(f,g\) 需要预处理。
考虑到 \(f\) 需要遍历所有儿子, \(g\) 需要遍历所有重儿子,因此,在树剖第一遍 \(dfs\) 中预处理 \(f\) ,第二遍 \(dfs\) 中预处理 \(g\) ,写起来会比较简洁。
-
线段树的
pushup
要用右边乘左边。为什么呢?矩阵乘法不满足交换律,又因为一条重链是在从底向上更新 \(\texttt{dp}\) 值,所以体现在 \(dfs\) 序上就是从右往左乘。
-
树剖时预处理每条重链的链底,查询时需要
query
整条链。修改时
query
的区间为[dfn[top[u]],ed[u]]
(ed[u]
和ed[top[u]]
本质相同),注意不要把右端点写成u
!
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+5,inf=1e9;
int m,n,u,v,cnt;
int a[maxn];
int d[maxn],fa[maxn],sz[maxn],son[maxn];
int ed[maxn],id[maxn],dfn[maxn],top[maxn];
int f[maxn][2],g[maxn][2];
vector<int> h[maxn];
struct mat
{
int v[2][2];
};
mat operator*(const mat &a,const mat &b)
{
static mat c;
for(int i=0;i<=1;i++)
for(int j=0;j<=1;j++)
{
c.v[i][j]=-inf;
for(int k=0;k<=1;k++) c.v[i][j]=max(c.v[i][j],a.v[i][k]+b.v[k][j]);
}
return c;
}
namespace sgmt
{
#define ls p<<1
#define rs p<<1|1
struct node
{
int l,r;
mat x;
}f[4*maxn];
void pushup(int p)
{
f[p].x=f[rs].x*f[ls].x;
}
void build(int p,int l,int r)
{
f[p].l=l,f[p].r=r;
if(l==r)
{
int x=id[l];
return f[p].x={g[x][0],g[x][1],g[x][0],-inf},void();
}
int mid=(l+r)/2;
build(ls,l,mid);
build(rs,mid+1,r);
pushup(p);
}
void modify(int p,int pos,mat x)
{
if(f[p].l==f[p].r) return f[p].x=x,void();
int mid=(f[p].l+f[p].r)/2;
if(pos<=mid) modify(ls,pos,x);
else modify(rs,pos,x);
pushup(p);
}
mat query(int p,int l,int r)
{
if(l<=f[p].l&&f[p].r<=r) return f[p].x;
int mid=(f[p].l+f[p].r)/2;
if(r<=mid) return query(ls,l,r);
if(l>=mid+1) return query(rs,l,r);
return query(rs,l,r)*query(ls,l,r);///注意这里是从右往左乘
}
}
void dfs1(int u,int father)
{
sz[u]=1,f[u][1]=a[u];
for(auto v:h[u])
{
if(v==father) continue;
d[v]=d[u]+1,fa[v]=u;
dfs1(v,u),sz[u]+=sz[v];
if(sz[v]>=sz[son[u]]) son[u]=v;
f[u][0]+=max(f[v][0],f[v][1]),f[u][1]+=f[v][0];
}
}
void dfs2(int u,int topf)
{
dfn[u]=++cnt,id[cnt]=u,top[u]=topf,ed[u]=dfn[u],g[u][1]=a[u];
if(son[u]) dfs2(son[u],topf),ed[u]=ed[son[u]];
for(auto v:h[u])
{
if(v==fa[u]||v==son[u]) continue;
dfs2(v,v);
g[u][0]+=max(f[v][0],f[v][1]),g[u][1]+=f[v][0];
}
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
for(int i=1;i<=n-1;i++)
{
scanf("%d%d",&u,&v);
h[u].push_back(v),h[v].push_back(u);
}
d[1]=1,dfs1(1,0),dfs2(1,1);
sgmt::build(1,1,n);
while(m--)
{
scanf("%d%d",&u,&v);
g[u][1]+=v-a[u],a[u]=v;///更新u自身的信息
while(u)
{
mat lst=sgmt::query(1,dfn[top[u]],ed[u]);
sgmt::modify(1,dfn[u],{g[u][0],g[u][1],g[u][0],-inf});
mat now=sgmt::query(1,dfn[top[u]],ed[u]);
u=fa[top[u]];
g[u][0]-=max(lst.v[0][0],lst.v[0][1]),g[u][1]-=lst.v[0][0];///减掉旧的贡献
g[u][0]+=max(now.v[0][0],now.v[0][1]),g[u][1]+=now.v[0][0];///加入新的贡献
}
mat res=sgmt::query(1,dfn[1],ed[1]);
printf("%d\n",max(res.v[0][0],res.v[0][1]));
}
return 0;
}
总结
动态 \(\texttt{DP}\) 的操作流程:
-
在树剖过程中预处理 \(f,g\) 。
这里 \(f\) 表示考虑 \(u\) 子树的贡献, \(g\) 表示考虑 \(u\) 及其所有轻子树的贡献。
注意 \(g\) 的转移依赖于 \(f\) ,而不是 \(g\) 自身封闭转移。
一般来说,\(f\) 仅在预处理时会用到,修改后可以通过线段树查询得到真实的 \(f\) ;而 \(g\) 会在转移矩阵有所体现,修改需要实时维护。
-
线段树维护转移矩阵。
注意
pushup
是从右往左乘。void pushup(int p) { f[p].x=f[rs].x*f[ls].x; }
-
修改时先减掉原来的贡献,再加入更新之后的贡献。
下面是修改操作的伪代码:
///单点修改g[u],注意不需要上线段树 while(u) { mat lst=/**初始矩阵**/ * sgmt::query(1,dfn[top[u]],ed[u]); sgmt::modify(1,dfn[u],/**u的转移矩阵**/); mat now=/**初始矩阵**/ * sgmt::query(1,dfn[top[u]],ed[u]); u=fa[top[u]]; ///在g[u]中减掉lst的贡献 ///在g[u]中加入now的贡献 }
-
查询时 \(u\) 的答案为初始矩阵乘上自底向上整条链的转移矩阵。
如果初始矩阵刚好为矩阵乘法的单位元,那么可以偷懒不乘初始矩阵。比如非负整数域上 \((\max,+)\) 广义矩阵乘法的单位元为全零矩阵,模板题就是典型例子。
但是多数情况下初始矩阵不是单位元,也不可省略。
ans[u]=/**初始矩阵**/ * sgmt::query(1,dfn[u],ed[u]);
-
如果没有修改操作,或者所有查询操作在修改操作之后,然后询问多个点(多条路径)的答案,可以用倍增代替树剖,时间复杂度少一只 \(\log\) 。
三、相关例题
例2、\(\texttt{CF750E New Year and Old Subsequence}\)
题目描述
给定一个长为 \(n\) 的数字串 \(s\) 。
\(q\) 次询问,对于 \(s\) 的某个子串 \(s[l:r]\) ,至少要删去几个字符,才能使其包含序列 "2017"
,但不包含序列 "2016"
,无解输出 -1
。
数据范围
- \(4\le n\le 2\cdot 10^5,1\le q\le 2\cdot 10^5,1\le l\le r\le n\) 。
时间限制 \(\texttt{3s}\) ,空间限制 \(\texttt{256MB}\) 。
分析
严格来说,本题不算动态 \(\texttt{DP}\) ,只能算线段树维护矩阵乘法。
注意认真审题,要求包含的是子序列而不是子串。
考虑在子序列自动机上 \(\texttt{dp}\) ,状态设计如下:
- \(f_{i,0}\) 表示走到状态 \(\varnothing\) ,至少需要删几个字符。
- \(f_{i,1}\) 表示走到状态
"2"
,至少需要删几个字符。 - \(f_{i,2}\) 表示走到状态
"20"
,至少需要删几个字符。 - \(f_{i,3}\) 表示走到状态
"201"
,至少需要删几个字符。 - \(f_{i,4}\) 表示走到状态
"2017"
,至少需要删几个字符。
记 \(c=s_i\) ,可以写出转移方程:
注: \(f[c=x]\) 要求 \(c=x\) 时才能转移, \(f+[c=x]\) 将 \(c=x\) 视为一个布尔函数。
至此已经做到 \(\mathcal O(nq)\) ,考虑继续优化。
把转移方程看成 \((\min,+)\) 广义矩阵乘法,那么每个 \(c\) 都可以预处理出转移矩阵。
每次询问用初始矩阵 \(\begin{bmatrix}0&\infty&\infty&\infty&\infty\\\end{bmatrix}\) 乘上区间 \([l,r]\) 的所有矩阵,线段树维护即可。
时间复杂度 \(\mathcal O(w^3(n+q)\log n)\) ,本题 \(w=5\) 。
可以用向量乘矩阵的 \(\texttt{trick}\) 优化到\(O(w^3n\log n+w^2q\log n)\),但本题 \(n,q\) 同阶,所以意义不大。
#include<bits/stdc++.h>
#define ls p<<1
#define rs p<<1|1
using namespace std;
const int maxn=2e5+5,inf=1e9;
int l,n,q,r;
char s[maxn];
struct mat
{
int v[5][5];
mat()
{
for(int i=0;i<=4;i++)
for(int j=0;j<=4;j++)
v[i][j]=inf;
}
}c[10];
mat operator*(const mat &a,const mat &b)
{
static mat c;
for(int i=0;i<=4;i++)
for(int j=0;j<=4;j++)
{
c.v[i][j]=inf;
for(int k=0;k<=4;k++) c.v[i][j]=min(c.v[i][j],a.v[i][k]+b.v[k][j]);
}
return c;
}
struct node
{
int l,r;
mat x;
}f[4*maxn];
void build(int p,int l,int r)
{
f[p].l=l,f[p].r=r;
if(l==r) return f[p].x=c[s[l]-'0'],void();
int mid=(l+r)/2;
build(ls,l,mid);
build(rs,mid+1,r);
f[p].x=f[ls].x*f[rs].x;
}
mat query(int p,int l,int r)
{
if(l<=f[p].l&&f[p].r<=r) return f[p].x;
int mid=(f[p].l+f[p].r)/2;
if(r<=mid) return query(ls,l,r);
if(l>=mid+1) return query(rs,l,r);
return query(ls,l,r)*query(rs,l,r);
}
int main()
{
scanf("%d%d%s",&n,&q,s+1);
for(int i=0;i<=9;i++)
for(int j=0;j<=4;j++)
c[i].v[j][j]=0;
c[2].v[0][0]=1,c[2].v[0][1]=0;
c[0].v[1][1]=1,c[0].v[1][2]=0;
c[1].v[2][2]=1,c[1].v[2][3]=0;
c[7].v[3][3]=1,c[7].v[3][4]=0;
c[6].v[3][3]=1,c[6].v[4][4]=1;
build(1,1,n);
while(q--)
{
scanf("%d%d",&l,&r);
mat res;
res.v[0][0]=0,res=res*query(1,l,r);
printf("%d\n",res.v[0][4]!=inf?res.v[0][4]:-1);
}
return 0;
}
例3、\(\texttt{P6021 洪水}\)
题目描述
给定一棵 \(n\) 个节点的树,你可以花费 \(w_i\) 的代价删掉第 \(i\) 个点。
接下来 \(m\) 次操作:
Q x
:询问如果要使 \(x\) 与其子树中所有叶节点不连通,花费代价的最小值。C x y
:将 \(w_x\) 加上 \(y\) 。
数据范围
- \(1\le n,m\le 2\cdot 10^5\) 。
- 保证任意时刻 \(0\le w_i\lt 2^{31}\) 。
时间限制 \(\texttt{1s}\) ,空间限制 \(\texttt{125MB}\) 。
分析
记 \(f_u\) 为让 \(u\) 与其子树所有叶子不连通的最小代价,容易写出转移方程:
记 \(g_u=\sum\limits_{v\neq wson_u}f_v\) ,则 \(f_u=\min(w_u,f_{wson_u}+g_u)\) 。
然后写成 \((\min,+)\) 广义矩阵乘法:
动态 \(\texttt{DP}\) 维护即可,时间复杂度 \(\mathcal O(w^3n\log^2n)\) ,本题 \(w=2\) 。
注意叶子节点 \(wson_u\) 不存在,因此 query
时需要用初始矩阵 \(\begin{bmatrix}\infty&0\\\end{bmatrix}\) 乘上整条链的贡献。
#include<bits/stdc++.h>
#define ll long long
#define ls p<<1
#define rs p<<1|1
using namespace std;
const int maxn=2e5+5;
const ll inf=1e18;
int m,n,u,v,cnt;
int d[maxn],fa[maxn],sz[maxn],son[maxn];
int ed[maxn],id[maxn],dfn[maxn],top[maxn];
ll f[maxn],g[maxn],w[maxn];
char ch[2];
vector<int> h[maxn];
struct mat
{
ll v[2][2];
};
mat operator*(const mat &a,const mat &b)
{
static mat c;
for(int i=0;i<=1;i++)
for(int j=0;j<=1;j++)
{
c.v[i][j]=inf;
for(int k=0;k<=1;k++) c.v[i][j]=min(c.v[i][j],a.v[i][k]+b.v[k][j]);
}
return c;
}
namespace sgmt
{
struct node
{
int l,r;
mat x;
}f[4*maxn];
void pushup(int p)
{
f[p].x=f[rs].x*f[ls].x;
}
void build(int p,int l,int r)
{
f[p].l=l,f[p].r=r;
if(l==r)
{
int u=id[l];
return f[p].x={g[u],inf,w[u],0},void();
}
int mid=(l+r)/2;
build(ls,l,mid);
build(rs,mid+1,r);
pushup(p);
}
void modify(int p,int pos,mat x)
{
if(f[p].l==f[p].r) return f[p].x=x,void();
int mid=(f[p].l+f[p].r)/2;
if(pos<=mid) modify(ls,pos,x);
else modify(rs,pos,x);
pushup(p);
}
mat query(int p,int l,int r)
{
if(l<=f[p].l&&f[p].r<=r) return f[p].x;
int mid=(f[p].l+f[p].r)/2;
if(r<=mid) return query(ls,l,r);
if(l>=mid+1) return query(rs,l,r);
return query(rs,l,r)*query(ls,l,r);
}
}
using sgmt::modify;
using sgmt::query;
void dfs1(int u,int father)
{
sz[u]=1;
for(auto v:h[u])
{
if(v==father) continue;
d[v]=d[u]+1,fa[v]=u;
dfs1(v,u),sz[u]+=sz[v];
if(sz[v]>=sz[son[u]]) son[u]=v;
f[u]+=f[v];
}
f[u]=sz[u]==1?w[u]:min(f[u],w[u]);
}
void dfs2(int u,int topf)
{
dfn[u]=++cnt,id[cnt]=u,top[u]=topf,ed[u]=cnt;
if(son[u]) dfs2(son[u],topf),ed[u]=ed[son[u]];
for(auto v:h[u])
{
if(v==fa[u]||v==son[u]) continue;
dfs2(v,v);
g[u]+=f[v];
}
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%lld",&w[i]);
for(int i=1;i<=n-1;i++)
{
scanf("%d%d",&u,&v);
h[u].push_back(v),h[v].push_back(u);
}
dfs1(1,0),dfs2(1,1);
sgmt::build(1,1,n);
scanf("%d",&m);
while(m--)
{
scanf("%s",ch);
if(ch[0]=='Q')
{
scanf("%d",&u);
mat res=(mat){inf,0,0,0}*query(1,dfn[u],ed[u]);
printf("%lld\n",res.v[0][0]);
}
else
{
scanf("%d%d",&u,&v),w[u]+=v;
while(u)
{
mat lst=(mat){inf,0,0,0}*query(1,dfn[top[u]],ed[u]);
modify(1,dfn[u],{g[u],inf,w[u],0});
mat now=(mat){inf,0,0,0}*query(1,dfn[top[u]],ed[u]);
u=fa[top[u]];
g[u]+=now.v[0][0]-lst.v[0][0];
}
}
}
return 0;
}
例4、\(\texttt{P5024 [NOIP2018 提高组] 保卫王国}\)
题目描述
给定一棵 \(n\) 个节点的树,点有点权 \(w_i\) 。
\(m\) 次询问,每次分别钦定 \(a,b\) 必须选/不选,求最小权点覆盖,无解输出 -1
。
数据范围
- \(1\le n,m,w_i\le10^5,1\le a\neq b\le n\) 。
时间限制 \(\texttt{2s}\) ,空间限制 \(\texttt{512MB}\) 。
分析
强制选点/不选点可以通过给权值加上/减去 \(\inf\) 实现。
熟知结论:最小权点覆盖等于总权值减去最大权独立集。
带修的最大权独立集就是模板题干的事情。
时间复杂度 \(\mathcal O(w^3m\log^2n)\) ,本题 \(w=2\) 。
#include<bits/stdc++.h>
#define ll long long
#define ls p<<1
#define rs p<<1|1
using namespace std;
const int maxn=1e5+5;
const ll inf=1e18;
int a,b,m,n,u,v,cnt;
ll sum;
int d[maxn],fa[maxn],sz[maxn],son[maxn];
int ed[maxn],id[maxn],dfn[maxn],top[maxn];
ll f[maxn][2],g[maxn][2],w[maxn];
vector<int> h[maxn];
struct mat
{
ll v[2][2];
};
mat operator*(const mat &a,const mat &b)
{
static mat c;
for(int i=0;i<=1;i++)
for(int j=0;j<=1;j++)
{
c.v[i][j]=-inf;
for(int k=0;k<=1;k++) c.v[i][j]=max(c.v[i][j],a.v[i][k]+b.v[k][j]);
}
return c;
}
namespace sgmt
{
struct node
{
int l,r;
mat x;
}f[4*maxn];
void pushup(int p)
{
f[p].x=f[rs].x*f[ls].x;
}
void build(int p,int l,int r)
{
f[p].l=l,f[p].r=r;
if(l==r)
{
int u=id[l];
return f[p].x={g[u][0],g[u][1],g[u][0],-inf},void();
}
int mid=(l+r)/2;
build(ls,l,mid);
build(rs,mid+1,r);
pushup(p);
}
void modify(int p,int pos,mat x)
{
if(f[p].l==f[p].r) return f[p].x=x,void();
int mid=(f[p].l+f[p].r)/2;
if(pos<=mid) modify(ls,pos,x);
else modify(rs,pos,x);
pushup(p);
}
mat query(int p,int l,int r)
{
if(l<=f[p].l&&f[p].r<=r) return f[p].x;
int mid=(f[p].l+f[p].r)/2;
if(r<=mid) return query(ls,l,r);
if(l>=mid+1) return query(rs,l,r);
return query(rs,l,r)*query(ls,l,r);
}
}
void dfs1(int u,int father)
{
sz[u]=1,f[u][1]=w[u];
for(auto v:h[u])
{
if(v==father) continue;
d[v]=d[u]+1,fa[v]=u;
dfs1(v,u),sz[u]+=sz[v];
if(sz[v]>=sz[son[u]]) son[u]=v;
f[u][0]+=max(f[v][0],f[v][1]),f[u][1]+=f[v][0];
}
}
void dfs2(int u,int topf)
{
dfn[u]=++cnt,id[cnt]=u,ed[u]=cnt,top[u]=topf,g[u][1]=w[u];
if(son[u]) dfs2(son[u],topf),ed[u]=ed[son[u]];
for(auto v:h[u])
{
if(v==fa[u]||v==son[u]) continue;
dfs2(v,v);
g[u][0]+=max(f[v][0],f[v][1]),g[u][1]+=f[v][0];
}
}
void modify(int u,ll tag)
{
g[u][1]+=tag,sum+=tag;
while(u)
{
mat lst=sgmt::query(1,dfn[top[u]],ed[u]);
sgmt::modify(1,dfn[u],{g[u][0],g[u][1],g[u][0],-inf});
mat now=sgmt::query(1,dfn[top[u]],ed[u]);
u=fa[top[u]];
g[u][0]-=max(lst.v[0][0],lst.v[0][1]),g[u][1]-=lst.v[0][0];
g[u][0]+=max(now.v[0][0],now.v[0][1]),g[u][1]+=now.v[0][0];
}
}
int main()
{
scanf("%d%d%*s",&n,&m);
for(int i=1;i<=n;i++) scanf("%lld",&w[i]),sum+=w[i];
for(int i=1;i<=n-1;i++)
{
scanf("%d%d",&u,&v);
h[u].push_back(v),h[v].push_back(u);
}
dfs1(1,0),dfs2(1,1);
sgmt::build(1,1,n);
while(m--)
{
scanf("%d%d%d%d",&u,&a,&v,&b);
modify(u,a?-inf:inf),modify(v,b?-inf:inf);
mat cur=sgmt::query(1,1,ed[1]);
ll res=sum-max(cur.v[0][0],cur.v[0][1])+(a+b)*inf;
printf("%lld\n",res<inf?res:-1);
modify(u,a?inf:-inf),modify(v,b?inf:-inf);
}
return 0;
}
例5、\(\texttt{LOJ3539 「JOI Open 2018」猫或狗}\)
题目描述
给定一棵 \(n\) 个节点的树,点有点权 \(w_i\) ,满足 \(w_i\in\{0,1,2\}\) ,初始全为 \(2\) 。
接下来 \(q\) 次操作:先单点修改点权,再询问为使所有权值为 \(0\) 的点不与权值为 \(1\) 的点连通,至少要删几条边。
数据范围
- \(1\le n,q\le 10^5\) 。
时间限制 \(\texttt{3s}\) ,空间限制 \(\texttt{512MB}\) 。
分析
目标等价于 \(\forall u\) , \(u\) 不能既与 \(0\) 连通又不与 \(1\) 连通。
\(f_{u,0/1}\)表示使点 \(u\) 不和子树内的 \(0/1\) 连通,至少要删几条边。
注:对于本题,中括号不满足时该项视为\(\infty\)。
然后分离轻重子树的贡献。定义:
再用 \(g\) 化简 \(f\) :
写成 \((\min,+)\) 转移矩阵:
然后动态 \(\texttt{DP}\) 就可以了。
时间复杂度 \(\mathcal O(w^3(n+q)\log^2n)\) ,本题 \(w=2\) 。
#include<bits/stdc++.h>
#include"catdog.h"
#define ls p<<1
#define rs p<<1|1
using namespace std;
const int maxn=1e5+5,inf=1e9;
int n,cnt;
int d[maxn],fa[maxn],sz[maxn],son[maxn];
int w[maxn],ed[maxn],id[maxn],dfn[maxn],top[maxn];
int f[maxn][2],g[maxn][2];
vector<int> e[maxn];
struct mat
{
int v[2][2];
};
mat operator*(const mat &a,const mat &b)
{
static mat c;
for(int i=0;i<=1;i++)
for(int j=0;j<=1;j++)
{
c.v[i][j]=inf;
for(int k=0;k<=1;k++) c.v[i][j]=min(c.v[i][j],a.v[i][k]+b.v[k][j]);
}
return c;
}
namespace sgmt
{
struct node
{
int l,r;
mat x;
void init(int u)
{
x={w[u]!=0?g[u][0]:inf,w[u]!=1?g[u][1]+1:inf,w[u]!=0?g[u][0]+1:inf,w[u]!=1?g[u][1]:inf};
}
}f[maxn<<2];
void pushup(int p)
{
f[p].x=f[rs].x*f[ls].x;
}
void build(int p,int l,int r)
{
f[p].l=l,f[p].r=r;
if(l==r) return f[p].init(id[l]);
int mid=(l+r)/2;
build(ls,l,mid);
build(rs,mid+1,r);
pushup(p);
}
void modify(int p,int pos)
{
if(f[p].l==f[p].r) return f[p].init(id[pos]);
int mid=(f[p].l+f[p].r)/2;
if(pos<=mid) modify(ls,pos);
else modify(rs,pos);
pushup(p);
}
mat query(int p,int l,int r)
{
if(l<=f[p].l&&f[p].r<=r) return f[p].x;
int mid=(f[p].l+f[p].r)/2;
if(r<=mid) return query(ls,l,r);
if(l>=mid+1) return query(rs,l,r);
return query(rs,l,r)*query(ls,l,r);
}
}
void dfs1(int u,int father)
{
sz[u]=1;
for(auto v:e[u])
{
if(v==father) continue;
d[v]=d[u]+1,fa[v]=u;
dfs1(v,u),sz[u]+=sz[v];
if(sz[v]>=sz[son[u]]) son[u]=v;
f[u][0]+=min(f[v][0],f[v][1]+1),f[u][1]+=min(f[v][1],f[v][0]+1);
}
if(w[u]!=2) f[u][w[u]]=inf;
}
void dfs2(int u,int topf)
{
dfn[u]=++cnt,id[cnt]=u,ed[u]=cnt,top[u]=topf;
if(son[u]) dfs2(son[u],topf),ed[u]=ed[son[u]];
for(auto v:e[u])
{
if(v==fa[u]||v==son[u]) continue;
dfs2(v,v);
g[u][0]+=min(f[v][0],f[v][1]+1),g[u][1]+=min(f[v][1],f[v][0]+1);
}
}
void initialize(int _n,vector<int> a,vector<int> b)
{
n=_n;
for(int i=0;i<n-1;i++)
{
int u=a[i],v=b[i];
e[u].push_back(v),e[v].push_back(u);
}
for(int i=1;i<=n;i++) w[i]=2;
d[1]=1,dfs1(1,0),dfs2(1,1);
sgmt::build(1,1,n);
}
int modify(int u,int x)
{
w[u]=x;
while(u)
{
mat lst=(mat){0,0,0,0}*sgmt::query(1,dfn[top[u]],ed[u]);
sgmt::modify(1,dfn[u]);
mat now=(mat){0,0,0,0}*sgmt::query(1,dfn[top[u]],ed[u]);
u=fa[top[u]];
g[u][0]-=min(lst.v[0][0],lst.v[0][1]+1),g[u][1]-=min(lst.v[0][1],lst.v[0][0]+1);
g[u][0]+=min(now.v[0][0],now.v[0][1]+1),g[u][1]+=min(now.v[0][1],now.v[0][0]+1);
}
mat res=(mat){0,0,0,0}*sgmt::query(1,1,ed[1]);
return min(res.v[0][0],res.v[0][1]);
}
int cat(int u)
{
return modify(u,0);
}
int dog(int u)
{
return modify(u,1);
}
int neighbor(int u)
{
return modify(u,2);
}
例6、\(\texttt{LOJ2269 「SDOI2017」切树游戏}\)
题目描述
给定一棵 \(n\) 个节点的树,点有点权 \(w_i\) 。
接下来 \(q\) 次操作:
Change x y
:将 \(w_x\) 改为 \(y\) 。Query k
:询问有多少个非空连通块,权值异或和为 \(k\) 。
数据范围
- \(1\le n,q\le 3\cdot 10^4\) ,保证
Change
操作个数 \(\le 10^4\) 。 - \(4\le m\le 128\) ,保证 \(m\) 为 \(2\) 的方幂。
- \(0\le w_i,y\lt m\) 。
时间限制 \(\texttt{3s}\) ,空间限制 \(\texttt{512MB}\) 。
分析
\(f_{u,i}\) 表示以 \(u\) 为根的所有连通块中,权值异或和为 \(i\) 的连通块个数。
每次合并一棵 \(u\) 的子树 \(v\) ,转移方程为\(f'_{u,i}=f_{u,i}+\sum\limits_{x\oplus y=i}f_{u,x}\cdot f_{v,y}\)。
这里 \(f_u\)表示转移前的 \(\texttt{dp}\) 值, \(f'_u\) 表示转移后的 \(\texttt{dp}\) 值。
为节约篇幅,记 \(\hat f_u=\text{FWT}[f_u]\) ,注意 \(\hat{f_u+1}=\hat f_u+\hat 1\) 。
两边同时取 \(\text{FWT}\) :
记 \(f_u\) 的初始值 \(p_u=x^{w_u}\) ,最终的转移方程为:
最后输出 \(\sum_{u=1}^nf_{u,k}\) 即可,于是我们得到了一个 \(\mathcal O(nmq)\) 的算法。
由于询问要对所有的 \(u\) 求和,因此再记一个\(h_{u,i}=\sum\limits_{v\in subtree(u)}f_{v,i}\)。
转移方程 \(h_{u,i}=f_{u,i}+\sum\limits_{v\in son(u)}h_{v,i}\)。
两边同时取 \(\text{FWT}\) :
接下来是动态 \(\texttt{DP}\) 的基本操作,分离轻重子树的贡献。
定义:
转移方程为:
然后写成矩阵乘法:
矩阵乘法自带 \(27\) 倍常数,显然无法承受的。
注意到 \(\begin{bmatrix}a&b&0\\0&\hat 1&0\\c&d&\hat 1\\\end{bmatrix}\) 做矩阵乘法时关于这个形式封闭:
我们只需维护 \(a,b,c,d\) 四个值,常数从 \(27\) 降为 \(4\) 。
询问时初始矩阵为 \(\begin{bmatrix}0&0&\hat 1\\\end{bmatrix}\) ,但代码实现没这么麻烦,因为 \(\hat f_u,\hat h_u\) 分别对应转移矩阵中的 \(c,d\) 位置。
具体可以看这篇blog中“基于变换合并的算法”一栏。
但本题并没有结束,动态 \(\texttt{DP}\) 中修改需要删除原来的贡献。
其中 \(\hat l_u\) 可以直接减,但 \(\hat g_u\) 需要做除法,而零没有乘法逆元。
对 \(\forall 1\le u\le n\) ,我们额外用一个数组 z[u]
维护 g[u]
中乘零的次数,这样乘零和除零可以直接在 z[u]
上修改。
时间复杂度 \(\mathcal O(4\cdot 10^4\cdot m\log^2n)\),前面的 \(4\) 为矩阵乘法的常数。
#include<bits/stdc++.h>
#define poly array<int,128>
using namespace std;
const int maxn=3e4+5,mod=1e4+7,inv2=(mod+1)>>1;
int m,n,q,u,v,cnt;
int ed[maxn],fa[maxn],sz[maxn],dfn[maxn],son[maxn],top[maxn];
int w[maxn],inv[mod];
char ch[10];
poly f[maxn],g[maxn],h[maxn],l[maxn],z[maxn],p[128];
vector<int> e[maxn];
int qpow(int a,int k)
{
int res=1;
for(;k;a=a*a%mod,k>>=1) if(k&1) res=res*a%mod;
return res;
}
int add(int x,int y)
{
if((x+=y)>=mod) x-=mod;
return x;
}
int dec(int x,int y)
{
if((x-=y)<0) x+=mod;
return x;
}
void fwt(poly &f,int n,int op)
{
for(int k=2,m=1;k<=n;k<<=1,m<<=1)
for(int i=0;i<n;i+=k)
for(int j=i;j<i+m;j++)
{
int x=f[j],y=f[j+m];
f[j]=add(x,y),f[j+m]=dec(x,y);
if(op==-1) f[j]=f[j]*inv2%mod,f[j+m]=f[j+m]*inv2%mod;
}
}
poly operator+(poly a,poly b)
{
for(int i=0;i<m;i++) a[i]=add(a[i],b[i]);
return a;
}
poly operator-(poly a,poly b)
{
for(int i=0;i<m;i++) a[i]=dec(a[i],b[i]);
return a;
}
poly operator*(poly a,poly b)
{
for(int i=0;i<m;i++) a[i]=a[i]*b[i]%mod;
return a;
}
void operator+=(poly &a,poly b)
{
a=a+b;
}
void operator-=(poly &a,poly b)
{
a=a-b;
}
void operator*=(poly &a,poly b)
{
a=a*b;
}
void dfs1(int u,int _f)
{
sz[u]=1,f[u]=p[w[u]];
for(auto v:e[u])
{
if(v==_f) continue;
dfs1(v,u),fa[v]=u,sz[u]+=sz[v];
if(sz[v]>=sz[son[u]]) son[u]=v;
f[u]*=f[v]+p[0],h[u]+=h[v];
}
h[u]+=f[u];
}
void dfs2(int u,int topf)
{
dfn[u]=++cnt,top[u]=topf,ed[u]=cnt;
for(int i=0;i<m;i++) g[u][i]=1;
if(son[u]) dfs2(son[u],topf),ed[u]=ed[son[u]];
for(auto v:e[u])
{
if(v==fa[u]||v==son[u]) continue;
dfs2(v,v);
l[u]+=h[v];
poly tmp=f[v]+p[0];
for(int i=0;i<m;i++) tmp[i]?g[u][i]=g[u][i]*tmp[i]%mod:z[u][i]++;
}
}
struct mat
{
poly a,b,c,d;
mat operator*(const mat &x)
{
return {a*x.a,a*x.b+b,c*x.a+x.c,c*x.b+d+x.d};
}
};
mat get_mat(int u)
{
static poly tmp;
for(int i=0;i<m;i++) tmp[i]=z[u][i]?0:p[w[u]][i]*g[u][i]%mod;
return {tmp,tmp,tmp,l[u]+tmp};
}
namespace sgmt
{///博主为了偷懒写了zkw线段树,但是直接导致用时翻倍,大家不要学我qwq
int p;
mat val[3*maxn];
void build()
{
p=1<<(__lg(n+1)+1);
for(int i=1;i<=n;i++) val[p+dfn[i]]=get_mat(i);
for(int i=(p+n)>>1;i;i--) val[i]=val[i<<1|1]*val[i<<1];
}
void modify(int u)
{
val[p+dfn[u]]=get_mat(u);
for(int i=(p+dfn[u])>>1;i;i>>=1) val[i]=val[i<<1|1]*val[i<<1];
}
mat query(int l,int r)
{
static int st[20];
int top=0;
mat res={::p[0],f[0],f[0],f[0]};///f[0]指代常多项式0,p[0]指代常多项式1经过FWT后的结果
for(l=p+l-1,r=p+r+1;l^r^1;l>>=1,r>>=1)
{
if(~l&1) st[++top]=l^1;
if(r&1) res=res*val[r^1];
}
for(int i=top;i>=1;i--) res=res*val[st[i]];
return res;
}
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&w[i]);
for(int i=1;i<=n-1;i++)
{
scanf("%d%d",&u,&v);
e[u].push_back(v),e[v].push_back(u);
}
for(int i=1;i<mod;i++) inv[i]=qpow(i,mod-2);
for(int i=0;i<m;i++) p[i][i]=1,fwt(p[i],m,1);
dfs1(1,0),dfs2(1,1);
sgmt::build();
for(scanf("%d",&q);q--;)
{
scanf("%s",ch);
if(ch[0]=='C')
{
scanf("%d%d",&u,&v),w[u]=v;
while(u)
{
mat lst=sgmt::query(dfn[top[u]],ed[u]);
sgmt::modify(u);
mat now=sgmt::query(dfn[top[u]],ed[u]);
u=fa[top[u]];
l[u]=l[u]-lst.d+now.d;
for(int i=0;i<m;i++)
{
int x=add(lst.c[i],p[0][i]),y=add(now.c[i],p[0][i]);
x?g[u][i]=g[u][i]*inv[x]%mod:z[u][i]--;
y?g[u][i]=g[u][i]*y%mod:z[u][i]++;
}
}
}
else
{
scanf("%d",&u);
poly tmp=sgmt::query(dfn[1],ed[1]).d;
fwt(tmp,m,-1),printf("%d\n",tmp[u]);
}
}
return 0;
}
本文来自博客园,作者:peiwenjun,转载请注明原文链接:https://www.cnblogs.com/peiwenjun/p/17070870.html