动态dp 学习笔记
前言
被迫营业*2,不过去仔细学一下也挺好的。为了营业去学了好多新东西(((
由于本人水平有限,如有不严谨的地方还请指出。
ddp 主要用来处理树上dp问题,有时候出题人比较恶心带上修改,ddp就是用来支持快速修改的。
使用的前提是转移比较简洁,可以写成矩阵,基本上都是单点修改。
我感觉这个算法还是看题比较好理解。
例题一
这里会介绍三种常见的ddp维护方法。
首先考虑不带修改的情况。
设 \(f(u, 0 / 1)\) 分别表示:强制选择 \(u\) 这个点,强制不选择 \(u\) 这个点 时,以 \(u\) 为根的子树的最大独立集。
这时候以 \(u\) 为根的子树的答案就是 \(\max(f(u,0),f(u,1))\)。
有转移
接下去引入 ddp 的思想。
考虑将重儿子和轻儿子分开考虑。
设 \(g(u, 0 / 1)\) 表示,只考虑 \(u\) 的轻儿子,强制选 \(u\) 或不选 \(u\) 这个点时,以 \(u\) 为根的子树的最大独立集。
设 \(wson(u)\) 为 \(u\) 的重儿子。
可以得到
注意转移方程中,把 \(a_u\) 的贡献加到了 \(g(u,1)\) 中,不然第四条多个 \(a_u\) 方程不够简洁,不方便写成矩阵。
把 \(f\) 的转移写成矩阵:
注意上面的 \(*\) 是广义矩阵乘法: \(a_{i,j}=\max\{b_{i,k}+c_{k,j}\}\),这个运算符也是有结合律的。
一开始的时候我们先一趟dp求出 \(f,g\)。
这时候对于点 \(u\) 为根的子树查询答案会非常方便:设 \(u\) 所在重链底端是 \(End(u)\),我们把 \(u\) 到 \(End(u)\) 这段区间的矩阵全部按顺序乘起来就好了。
可以脑补一下这个过程:重链底端是个叶子,然后不断加入重链周围的轻子树以及重儿子,拼凑成了整颗子树。
考虑如何带上修改。
假设修改了点 \(u\)。
直接影响到的是 \(f(u,1),g(u,1)\)。
接着可以想象,往祖先走的时候同重链的 \(f\) 都能被矩阵直接更新,那么影响到的就是所有轻边的转移。
于是考虑计算更新之后对于轻边父亲的 \(g\) 的贡献。
发现 \(g\) 的转移和 \(f\) 有关,并且我们可以知道这条重链的 \(f\) 以及更新前重链的 \(f\),那么把之前的贡献减掉,把现在的贡献加上,就更新完毕了。
快速查询两点间矩阵乘积可以通过 树链剖分+线段树维护区间矩阵乘积 来维护,而且修改也可以通过跳轻边很方便地维护。
总共会跳到 \(O(\log n)\) 条轻边,还有每跳一次线段树修改的 \(O(\log n)\),总复杂度是 \(O(n\log^2 n)\)。
实现的时候建议这种小型矩阵手动展开,常数上可以减小好多。
远古代码。但是好像不是特别丑
Code
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef double db;
#define pb(x) push_back(x)
#define mkp(x,y) make_pair(x,y)
inline int read() {
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)) {if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch))x=x*10+(ch^48),ch=getchar();
return x*f;
}
const int N=100005;
const int M=N<<2;
const int inf=1e8;
int n,m,a[N];
int siz[N],dfn[N],tmr,son[N],fa[N],top[N],rev[N],ed[N],f[N][2];
struct edge{
int nxt,to;
}e[N<<1];
int head[N],num_edge;
void addedge(int fr,int to){
++num_edge;
e[num_edge].nxt=head[fr];
e[num_edge].to=to;
head[fr]=num_edge;
}
struct Matrix{
int a[2][2];
Matrix(){a[0][0]=a[0][1]=a[1][0]=a[1][1]=-inf;}
int*operator[](const int&k){return a[k];}
Matrix operator * (const Matrix&b){
Matrix res;
// for(int i=0;i<2;++i)
// for(int j=0;j<2;++j)
// for(int k=0;k<2;++k)
// res.a[i][j]=max(res.a[i][j],a[i][k]+b.a[k][j]);
res[0][0]=max(a[0][0]+b.a[0][0],a[0][1]+b.a[1][0]);
res[0][1]=max(a[0][0]+b.a[0][1],a[0][1]+b.a[1][1]);
res[1][0]=max(a[1][0]+b.a[0][0],a[1][1]+b.a[1][0]);
res[1][1]=max(a[1][0]+b.a[0][1],a[1][1]+b.a[1][1]);
return res;
}
}mat[N],val[M];
void dfs1(int u,int ft){
siz[u]=1,f[u][1]=a[u];
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;if(v==ft)continue;
fa[v]=u,dfs1(v,u),siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
f[u][0]+=max(f[v][0],f[v][1]);
f[u][1]+=f[v][0];
}
}
void dfs2(int u,int tp){
top[u]=tp,dfn[u]=++tmr,rev[tmr]=u;
if(son[u])dfs2(son[u],tp),ed[u]=ed[son[u]];
else ed[u]=u;
int g[2];g[0]=0,g[1]=a[u];
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==son[u]||v==fa[u])continue;
dfs2(v,v);
g[0]+=max(f[v][0],f[v][1]);
g[1]+=f[v][0];
}
mat[u][0][0]=g[0],mat[u][0][1]=g[0];
mat[u][1][0]=g[1],mat[u][1][1]=-inf;
}
#define lc (p<<1)
#define rc (p<<1|1)
void pushup(int p){val[p]=val[lc]*val[rc];}
void build(int l,int r,int p){
if(l==r)return val[p]=mat[rev[l]],void();
int mid=(l+r)>>1;
build(l,mid,lc),build(mid+1,r,rc);
pushup(p);
}
Matrix query(int ql,int qr,int l=1,int r=n,int p=1){
if(ql<=l&&r<=qr)return val[p];
int mid=(l+r)>>1;
if(qr<=mid)return query(ql,qr,l,mid,lc);
if(mid<ql)return query(ql,qr,mid+1,r,rc);
return query(ql,qr,l,mid,lc)*query(ql,qr,mid+1,r,rc);
}
void change(int pos,int l=1,int r=n,int p=1){
if(l==r)return val[p]=mat[rev[l]],void();
int mid=(l+r)>>1;
if(pos<=mid)change(pos,l,mid,lc);
else change(pos,mid+1,r,rc);
pushup(p);
}
void update(int x,int v){
mat[x][1][0]+=v-a[x],a[x]=v;
while(x){
Matrix lst=query(dfn[top[x]],dfn[ed[x]]);
change(dfn[x]);
Matrix now=query(dfn[top[x]],dfn[ed[x]]);
x=fa[top[x]];
mat[x][0][0]+=max(now[0][0],now[1][0])-max(lst[0][0],lst[1][0]);
mat[x][0][1]=mat[x][0][0];
mat[x][1][0]+=now[0][0]-lst[0][0];
}
}
signed main(){
n=read(),m=read();
for(int i=1;i<=n;++i)a[i]=read();
for(int i=1;i<n;++i){
int x=read(),y=read();
addedge(x,y),addedge(y,x);
}
dfs1(1,0),dfs2(1,1),build(1,n,1);
while(m--){
int x=read(),v=read();
update(x,v);
Matrix t=query(dfn[1],dfn[ed[1]]);
printf("%d\n",max(t[0][0],t[1][0]));
}
return 0;
}
树剖+线段树的复杂度是两只 log,这使得人们思考有没有更快的方法。
可以发现上面那种方法的在做的其实就是维护链上矩阵积,维护链上信息使我们想到了 \(O(n\log n)\) 的 LCT。
考虑只维护实边信息,虚儿子信息在 access 的时候更新上去。
更新一个节点信息的时候可以先 access 再 splay,这时候修改它对于任何节点都是没有影响的,可以直接修改。
查询一个节点的信息会有点特殊,需要执行的操作是:\(access(fa_x),splay(x)\)。因为把 \(fa_x\) 以上的节点移到别的 splay 里面去,\(splay(x)\) 后 \(x\) 下面挂的节点才是 \(x\) 的子树内的节点。
这里附上一份 LCT 实现,顺便去学了一下。
成功把复杂度降掉一只 \(\log\),LCT 的常数非常大就是了。
说句闲话,这题貌似正常常数的 LCT 都能过去(
Code
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define mkp make_pair
#define pb push_back
#define sz(v) (int)(v).size()
typedef long long LL;
typedef double db;
template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
#define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
#define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
inline int read(){
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
return f?x:-x;
}
const int N = 1000005;
const int inf = 0x3f3f3f3f;
int n, m, a[N], lastans;
vector<int> e[N];
struct Matrix {
int a[2][2];
Matrix(){ memset(a, -0x3f, sizeof a); }
inline int* operator [](const int &k) { return a[k]; }
inline Matrix operator * (const Matrix &t) const {
Matrix res;
res.a[0][0] = max(a[0][0] + t.a[0][0], a[0][1] + t.a[1][0]);
res.a[0][1] = max(a[0][0] + t.a[0][1], a[0][1] + t.a[1][1]);
res.a[1][0] = max(a[1][0] + t.a[0][0], a[1][1] + t.a[1][0]);
res.a[1][1] = max(a[1][0] + t.a[0][1], a[1][1] + t.a[1][1]);
return res;
}
};
int fa[N], ch[N][2], dp[N][2];
Matrix val[N], sum[N];
inline bool nroot(int x) { return ch[fa[x]][0] == x || ch[fa[x]][1] == x; }
inline void pushup(int x) {
sum[x] = val[x];
if(ch[x][0]) sum[x] = sum[ch[x][0]] * sum[x];
if(ch[x][1]) sum[x] = sum[x] * sum[ch[x][1]];
}
inline void rotate(int x) {
int y = fa[x], z = fa[y], k = ch[y][1] == x, w = ch[x][!k];
if(nroot(y)) ch[z][ch[z][1] == y] = x;
ch[x][!k] = y, ch[y][k] = w;
fa[w] = y, fa[y] = x, fa[x] = z;
pushup(y);
}
inline void splay(int x) {
while(nroot(x)) {
int y = fa[x], z = fa[y];
if(nroot(y)) rotate((ch[z][1] == y) ^ (ch[y][1] == x) ? x : y);
rotate(x);
}
pushup(x);
}
inline void access(int x) {
for(int y = 0; x; x = fa[y = x]) {
splay(x);
if(y) {
val[x][0][0] -= max(sum[y][0][0], sum[y][1][0]);
val[x][0][1] = val[x][0][0];
val[x][1][0] -= sum[y][0][0];
}
if(ch[x][1]) {
int t = ch[x][1];
val[x][0][0] += max(sum[t][0][0], sum[t][1][0]);
val[x][0][1] = val[x][0][0];
val[x][1][0] += sum[t][0][0];
}
ch[x][1] = y, pushup(x);
}
}
void dfs(int u, int ft) {
dp[u][1] = a[u], fa[u] = ft;
for(int v : e[u]) if(v != ft) {
dfs(v, u);
dp[u][0] += max(dp[v][0], dp[v][1]);
dp[u][1] += dp[v][0];
}
val[u][0][0] = val[u][0][1] = dp[u][0];
val[u][1][0] = dp[u][1], val[u][1][1] = -inf;
sum[u] = val[u];
}
signed main() {
n = read(), m = read();
rep(i, 1, n) a[i] = read();
rep(i, 2, n) {
int x = read(), y = read();
e[x].pb(y), e[y].pb(x);
}
dfs(1, 0);
while(m--) {
int x = read() ^ lastans, y = read();
access(x), splay(x);
val[x][1][0] += y - a[x], a[x] = y, pushup(x);
splay(1);
printf("%d\n", lastans = max(sum[1][0][0], sum[1][1][0]));
}
return 0;
}
考虑到这棵树并不会动,用 LCT 维护有点浪费,于是考虑搞一种新的方法来划分树。
有人从上古论文里翻出来了一个科技叫做“全局平衡二叉树”。
注意到 LCT 就是每一条链建平衡树,考虑用类似的思想。
建立的方法就是:先树剖,对于每一条重链每次取带权重心建立二叉树,连实边。对于轻子树建立的二叉树的根往当前二叉树上的节点拉虚边。
容易发现在同一颗二叉树内往父亲跳的时候,每跳一次子树大小都会至少翻倍;切换一次二叉树意味着切换一次重边,只会有 \(O(\log n)\) 次。仔细想想,这两个 \(\log\) 并不是乘起来的,是加起来的,因为子树大小翻倍至多 \(\log n\) 次,跳轻链也至多 \(\log n\) 次。所以树高是 \(O(\log n)\) 级别的,粗略分析上限是 \(2\log n\),注意有个常数。
仍然采用矩阵维护,维护方法类似 LCT,只维护实边信息,虚边一路跳到根更新 \(g\)。
如果我们要查询某个子树的答案怎么办?
首先找到这个节点在全局平衡二叉树上所在的二叉树。
考虑到全局平衡二叉树上每一个由实边连接的二叉树都是一条重链,并且先序遍历就是这条重链,根据之前重链剖分时的思路,我们要求的是一个点到重链底端的矩阵积。
就相当于我们在二叉排序树上查询序列后缀积,这个随便写写就好了。
Code
#include <bits/stdc++.h>
using namespace std;
typedef double db;
typedef long long LL;
#define fi first
#define se second
#define pb push_back
#define mkp make_pair
#define rep(i, x, y) for(int i = x, i##end = y; i <= i##end; ++i)
#define per(i, x, y) for(int i = x, i##end = y; i >= i##end; --i)
template<class T> inline bool ckmax(T &x, T y) { return x < y ? x = y, 1 : 0; }
template<class T> inline bool ckmin(T &x, T y) { return x > y ? x = y, 1 : 0; }
inline int read() {
int x = 0, f = 1; char ch = getchar();
while(!isdigit(ch)) { if(ch == '-') f = 0; ch = getchar(); }
while(isdigit(ch)) x = x * 10 + ch - '0', ch = getchar();
return f ? x : -x;
}
const int N = 1000005;
const int inf = 0x3f3f3f3f;
int n, m, a[N], f[N][2], g[N][2];
vector<int> e[N];
namespace Tree {
int siz[N], fa[N], son[N];
void dfs(int u, int ft) {
siz[u] = 1, fa[u] = ft;
f[u][1] = a[u];
for(int v : e[u]) if(v != ft) {
dfs(v, u), siz[u] += siz[v];
if(siz[v] > siz[son[u]]) son[u] = v;
f[u][0] += max(f[v][0], f[v][1]);
f[u][1] += f[v][0];
}
g[u][0] = f[u][0] - max(f[son[u]][0], f[son[u]][1]);
g[u][1] = f[u][1] - f[son[u]][0];
}
}
struct Matrix {
int a[2][2];
Matrix(){ memset(a, -0x3f, sizeof a); }
inline int* operator [](const int &k) { return a[k]; }
inline Matrix operator * (const Matrix &t) const {
Matrix res;
res.a[0][0] = max(a[0][0] + t.a[0][0], a[0][1] + t.a[1][0]);
res.a[0][1] = max(a[0][0] + t.a[0][1], a[0][1] + t.a[1][1]);
res.a[1][0] = max(a[1][0] + t.a[0][0], a[1][1] + t.a[1][0]);
res.a[1][1] = max(a[1][0] + t.a[0][1], a[1][1] + t.a[1][1]);
return res;
}
void print() {
cerr << a[0][0] << ' ' << a[0][1] << '\n' << a[1][0] << ' ' << a[1][1] << '\n';
}
};
namespace bst {
int fa[N], ch[N][2], stk[N], top, tsz[N], rt;
bool isrt[N];
Matrix val[N], sum[N];
inline void pushup(int u) {
sum[u] = val[u];
if(ch[u][0]) sum[u] = sum[ch[u][0]] * sum[u];
if(ch[u][1]) sum[u] = sum[u] * sum[ch[u][1]];
}
inline int build2(int l, int r) {
if(l > r) return 0;
int ALL = 0, now = 0;
rep(i, l, r) ALL += tsz[i];
rep(i, l, r) {
now += tsz[i];
if(now << 1 >= ALL) {
int u = stk[i];
fa[ch[u][0] = build2(l, i - 1)] = u;
fa[ch[u][1] = build2(i + 1, r)] = u;
return pushup(u), u;
}
}
assert(0);
}
int build(int tp) {
for(int i = tp; i; i = Tree::son[i]) {
for(int v : e[i]) if(v != Tree::fa[i] && v != Tree::son[i])
fa[build(v)] = i;
val[i][0][0] = val[i][0][1] = g[i][0];
val[i][1][0] = g[i][1], val[i][1][1] = -inf;
}
top = 0;
for(int i = tp; i; i = Tree::son[i])
stk[++top] = i, tsz[top] = Tree::siz[i] - Tree::siz[Tree::son[i]];
int tmp = build2(1, top);
isrt[tmp] = 1;
return tmp;
}
void modify(int x, int y) {
val[x][1][0] += y - a[x], a[x] = y;
for(int i = x; i; i = fa[i]) {
Matrix pre = sum[i];
pushup(i);
Matrix suf = sum[i];
if(isrt[i] && fa[i]) {
int f = fa[i];
val[f][0][0] += max(suf[0][0], suf[1][0]) - max(pre[0][0], pre[1][0]);
val[f][0][1] = val[f][0][0];
val[f][1][0] += suf[0][0] - pre[0][0];
}
}
}
}
signed main() {
n = read(), m = read();
rep(i, 1, n) a[i] = read();
rep(i, 2, n) {
int x = read(), y = read();
e[x].pb(y), e[y].pb(x);
}
Tree::dfs(1, 0);
bst::rt = bst::build(1);
int lastans = 0;
while(m--) {
int x = read() ^ lastans, y = read();
bst::modify(x, y);
printf("%d\n", lastans = max(bst::sum[bst::rt][0][0], bst::sum[bst::rt][1][0]));
}
return 0;
}
到现在为止三种维护ddp常用的方法都已经介绍完毕,用哪种请读者自己选择。
从我个人角度不建议写树剖,因为复杂度多一只 \(\log\),可能被卡。而且其实树剖+线段树是三种写法里码量最大的。
给出我的实现下,这两道题在洛谷评测的最大测试点用时:
P4719(普通版) | P4751(加强版) | |
---|---|---|
树剖+线段树 | 179ms | >3.7s(TLE) |
LCT | 77ms | 2.95s |
全局平衡二叉树 | 55ms | 1.42s |
毕竟每个人的实现都会有偏差,但是总体是可以看出每种方法的常数差别,在加强版体现的尤为突出。
小总结
上面解题的步骤其实是比较清晰的,也是一般做 ddp 题的步骤:
-
写出不修改情况下的状态转移方程
-
分离轻重儿子的贡献
-
把转移写成矩阵
-
大力码码码
一般都会在最开始暴力跑一趟树形dp求出不包括重儿子的答案塞进矩阵。
修改一般采用的方法是,消去原贡献,加入新贡献。
查询只要理解ddp本质都没问题。
例题二
小清新题,和模板没太大区别。
题意:给一棵树,每个点有点权 \(a_i\),每次询问:在某个子树内以点权为代价删除(堵上)一些点使得根与子树内所有叶子不连通的最小代价;带单点修改。
不带修情况
设 \(f_i\) 表示把以 \(i\) 为根的子树完全堵上的答案。
转移简洁并且单点修改使我们想到使用 ddp 来维护。
分离轻重儿子(这里的 \(son(u)\) 表示 \(u\) 的重儿子):
其中 \(g_u\) 是轻儿子的贡献
把转移方程写成矩阵
重定义广义矩阵乘法:
构造矩阵:
如果矩阵是一维的
那么 \(x=g_u,y=a_u-f_{son(u)}\)
发现左矩阵做了一个重儿子的东西,不能做ddp。
考虑再加一维:
因为 \(a_u=a_u+0\) ,考虑直接把 \(p,q\) 设成 \(0\) , 那么根据转移方程,\(x=g_u,y=a_u\) ,这样 \(f_u\) 已经被正确表示了。
但是底下的 \(q\) 在矩乘之后不一定是 \(0\) ,考虑通过 \(z,w\) 来维护 \(q\) 。
直接展开 \(q\) ,\(q=\min(z+f_{son(u)},w)\) (\(p=0\) 就不写了)
\(f_{son(u)}\) 是非负的,所以 \(w=0,z\ge 0\) 即可
矩阵构造完毕!
封死一颗子树的代价就是 \(f_u\)。
这里要提一个小细节,就是叶子节点没有轻儿子的时候 \(g\) 怎么办。
我想到了两种处理方法:
一种处理方法是把 \(g_u\) 设为 \(a_u\),因为叶子节点在ddp的时候要满足 \(f_u=g_u\)。
还有一种方法就是把 \(g_u\) 设成 \(+\infty\),直接禁止从“封死所有子树”这种方法的转移。并且矩阵左上角不和自己的 \(a\) 取 \(\min\),只维护轻子树的 dp值 和,调用 dp 值的时候再和自己的 \(a\) 取 \(\min\)。
第一种在树剖的时候比较好搞;如果用 LCT 维护我暂时没想到什么好的维护方法所以用了第二种。
修改
把 \(x\) 加 \(v\)
直接影响的就是这个节点的 \(a_u\)。
但是叶子节点还得同时更改 \(g_u\),千万别忘。
至于轻边父亲的修改,减去原贡献加上新贡献就好了。
LCT 同理,不过在修改一个节点矩阵的时候要先 access
再 splay
,这时候修改它对于任何节点都是没有影响的。
查询
树剖直接用线段树把这个点到重链底端的矩阵全部乘起来就好了。
LCT 比较特殊。要把这个节点在 原树 上的父亲 access
一下,再 splay
这个节点,这样子这个节点的信息就是这颗子树的信息。
我access在lct上的父亲调了一晚上(((
矩乘只是 \(2\times2\times2\) ,建议手动暴力展开,可以快非常多。
因为树剖代码是在之前的代码上改的,怕有些地方与描述不同,因此重写了一份 LCT 的代码
树剖版代码
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef double db;
#define pb(x) push_back(x)
#define mkp(x,y) make_pair(x,y)
//#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
//char buf[1<<21],*p1=buf,*p2=buf;
inline int read() {
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)) {if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch))x=x*10+(ch^48),ch=getchar();
return x*f;
}
int rdc(){
char ch=getchar();
while(ch!='Q'&&ch!='C')ch=getchar();
return ch=='Q';
}
const int N=200005;
const LL inf=1e14;
const int T=N<<2;
int n,dp[N];
LL a[N];
int head[N],num_edge;
int dfn[N],rev[N],tmr,fa[N],siz[N],son[N],top[N],ed[N];
struct edge{
int nxt,to;
}e[N<<1];
void addedge(int fr,int to){
++num_edge;
e[num_edge].nxt=head[fr];
e[num_edge].to=to;
head[fr]=num_edge;
}
struct Matrix {
LL a[2][2];
Matrix(){a[0][0]=a[0][1]=a[1][0]=a[1][1]=inf;}
LL*operator[](const int&k){return a[k];}
Matrix operator * (const Matrix&b){
Matrix res;
res.a[0][0] = min(a[0][0]+b.a[0][0],a[0][1]+b.a[1][0]);
res.a[0][1] = min(a[0][0]+b.a[0][1],a[0][1]+b.a[1][1]);
res.a[1][0] = min(a[1][0]+b.a[0][0],a[1][1]+b.a[1][0]);
res.a[1][1] = min(a[1][0]+b.a[0][1],a[1][1]+b.a[1][1]);
return res;
}
void print(){
printf("%lld %lld\n%lld %lld\n\n",a[0][0],a[0][1],a[1][0],a[1][1]);
}
}mat[N],val[T];
void dfs1(int u,int ft){
if(!e[head[u]].nxt)return dp[u]=a[u],siz[u]=1,void();
LL sum=0;siz[u]=1;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;if(v==ft)continue;
fa[v]=u,dfs1(v,u),sum+=dp[v],siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
dp[u]=min(sum,a[u]);
}
void dfs2(int u,int tp){
top[u]=tp,dfn[u]=++tmr,rev[tmr]=u,ed[tp]=tmr;
mat[u][0][0]=0,mat[u][0][1]=a[u],
mat[u][1][0]=0,mat[u][1][1]=0;
if(!son[u])return mat[u][0][0]=a[u],void();
dfs2(son[u],tp);
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa[u]||v==son[u])continue;
dfs2(v,v),mat[u][0][0]+=dp[v];
}
}
#define lc (p<<1)
#define rc (p<<1|1)
void pushup(int p){val[p]=val[lc]*val[rc];}
void build(int l,int r,int p=1){
if(l==r)return val[p]=mat[rev[l]],void();
int mid=(l+r)>>1;
build(l,mid,lc),build(mid+1,r,rc);
pushup(p);
}
Matrix query(int ql,int qr,int l=1,int r=n,int p=1){
if(ql<=l&&r<=qr)return val[p];
int mid=(l+r)>>1;
if(qr<=mid)return query(ql,qr,l,mid,lc);
if(mid<ql)return query(ql,qr,mid+1,r,rc);
return query(ql,qr,l,mid,lc)*query(ql,qr,mid+1,r,rc);
}
void change(int pos,int l=1,int r=n,int p=1){
if(l==r)return val[p]=mat[rev[pos]],void();
int mid=(l+r)>>1;
if(pos<=mid)change(pos,l,mid,lc);
else change(pos,mid+1,r,rc);
pushup(p);
}
void update(int x,int v){
mat[x][0][1]+=v,a[x]+=v;
if(siz[x]==1)mat[x][0][0]+=v;
while(x){
Matrix lst=query(dfn[top[x]],ed[top[x]]);
change(dfn[x]);
Matrix now=query(dfn[top[x]],ed[top[x]]);
x=fa[top[x]];
mat[x][0][0]+=now[0][0]-lst[0][0];
}
}
signed main(){
n=read();
for(int i=1;i<=n;++i)a[i]=read();
for(int i=1;i<n;++i){
int x=read(),y=read();
addedge(x,y),addedge(y,x);
}
dfs1(1,0),dfs2(1,1),build(1,n);
for(int m=read();m;--m){
int opt=rdc(),x=read();
if(opt){
Matrix t=query(dfn[x],ed[top[x]]);
printf("%lld\n",t[0][0]);
}
else update(x,read());
}
return 0;
}
LCT 版代码
#include <bits/stdc++.h>
using namespace std;
typedef double db;
typedef long long LL;
#define fi first
#define se second
#define pb push_back
#define mkp make_pair
#define sz(v) (int)(v).size()
#define rep(i, x, y) for(int i = x, i##end = y; i <= i##end; ++i)
#define per(i, x, y) for(int i = x, i##end = y; i >= i##end; --i)
template<class T> inline bool ckmax(T &x, T y) { return x < y ? x = y, 1 : 0; }
template<class T> inline bool ckmin(T &x, T y) { return x > y ? x = y, 1 : 0; }
inline int read() {
int x = 0, f = 1; char ch = getchar();
while(!isdigit(ch)) { if(ch == '-') f = 0; ch = getchar(); }
while(isdigit(ch)) x = x * 10 + ch - '0', ch = getchar();
return f ? x : -x;
}
inline int rdch() {
char ch = getchar();
while(ch != 'Q' && ch != 'C') ch = getchar();
return ch == 'Q';
}
const int N = 200005;
const LL inf = 1e14;
int n, m, lef[N], tfa[N];
LL a[N], dp[N];
vector<int> e[N];
int fa[N], ch[N][2];
struct Matrix {
LL a[2][2];
Matrix() { a[0][0] = a[0][1] = a[1][0] = a[1][1] = inf; }
inline LL* operator [](const int &k) { return a[k]; }
inline Matrix operator * (const Matrix &t) const {
Matrix res;
res.a[0][0] = min(a[0][0] + t.a[0][0], a[0][1] + t.a[1][0]);
res.a[0][1] = min(a[0][0] + t.a[0][1], a[0][1] + t.a[1][1]);
res.a[1][0] = min(a[1][0] + t.a[0][0], a[1][1] + t.a[1][0]);
res.a[1][1] = min(a[1][0] + t.a[0][1], a[1][1] + t.a[1][1]);
return res;
}
} val[N], sum[N];
inline bool nroot(int x) { return ch[fa[x]][0] == x || ch[fa[x]][1] == x; }
inline void pushup(int x) {
sum[x] = val[x];
if(ch[x][0]) sum[x] = sum[ch[x][0]] * sum[x];
if(ch[x][1]) sum[x] = sum[x] * sum[ch[x][1]];
}
inline void rotate(int x) {
int y = fa[x], z = fa[y], k = ch[y][1] == x, w = ch[x][!k];
if(nroot(y)) ch[z][ch[z][1] == y] = x;
ch[x][!k] = y, ch[y][k] = w;
fa[w] = y, fa[y] = x, fa[x] = z;
pushup(y);
}
inline void splay(int x) {
while(nroot(x)) {
int y = fa[x], z = fa[y];
if(nroot(y)) rotate((ch[z][1] == y) ^ (ch[y][1] == x) ? x : y);
rotate(x);
}
pushup(x);
}
inline void access(int x) {
for(int y = 0; x; x = fa[y = x]) {
splay(x);
if(y) val[x][0][0] -= min(sum[y][0][0], sum[y][0][1]);
if(ch[x][1]) val[x][0][0] += min(sum[ch[x][1]][0][0], sum[ch[x][1]][0][1]);
ch[x][1] = y, pushup(x);
}
}
void dfs(int u, int ft) {
lef[u] = 1;
for(int v : e[u]) if(v != ft)
tfa[v] = fa[v] = u, dfs(v, u), dp[u] += dp[v], lef[u] = 0;
if(lef[u]) dp[u] = a[u];
val[u][0][0] = lef[u] ? inf : dp[u], val[u][0][1] = a[u];
val[u][1][0] = val[u][1][1] = 0;
pushup(u);
ckmin(dp[u], a[u]);
}
signed main() {
n = read();
rep(i, 1, n) a[i] = read();
rep(i, 2, n) {
int x = read(), y = read();
e[x].pb(y), e[y].pb(x);
}
dfs(1, 0);
for(m = read(); m--; ) {
int op = rdch(), x = read();
if(op) {
if(tfa[x]) access(tfa[x]);
splay(x), printf("%lld\n", min(sum[x][0][0], sum[x][0][1]));
} else {
int y = read();
access(x), splay(x);
val[x][0][1] += y;
pushup(x);
}
}
}
例题三
分别钦定两个城市必取或者必不取的最小独立集。
对于一定驻扎,把点权设为 \(-\infty\)。对于一定不驻扎,点权设为 \(+\infty\)。
然后跑最小独立集即可。
最后输出的时候加或减一下之前偏移的 \(\infty\)。
可以偷懒把点权取反拉最大独立集的板子(
但是这里修改带四倍常数,用树剖写时间非常紧,uoj 上根本过不去,建议写全局平衡二叉树,我懒得重写了。
Code
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef double db;
#define pb(x) push_back(x)
#define mkp(x,y) make_pair(x,y)
//#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
//char buf[1<<21],*p1=buf,*p2=buf;
inline int read() {
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)) {if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch))x=x*10+(ch^48),ch=getchar();
return x*f;
}
const int N=100005;
const int M=N<<2;
const LL inf=1e12;
int n,m;
LL p[N];
int siz[N],dfn[N],tmr,son[N],fa[N],top[N],rev[N],ed[N],f[N][2];
char cynAKIOI[114514];
struct edge{
int nxt,to;
}e[N<<1];
int head[N],num_edge;
void addedge(int fr,int to){
++num_edge;
e[num_edge].nxt=head[fr];
e[num_edge].to=to;
head[fr]=num_edge;
}
struct Matrix{
LL p[2][2];
Matrix(){p[0][0]=p[0][1]=p[1][0]=p[1][1]=-inf;}
LL*operator[](const int&k){return p[k];}
Matrix operator * (const Matrix&b){
Matrix res;
// for(int i=0;i<2;++i)
// for(int j=0;j<2;++j)
// for(int k=0;k<2;++k)
// res.p[i][j]=max(res.p[i][j],p[i][k]+b.p[k][j]);
res[0][0]=max(p[0][0]+b.p[0][0],p[0][1]+b.p[1][0]);
res[0][1]=max(p[0][0]+b.p[0][1],p[0][1]+b.p[1][1]);
res[1][0]=max(p[1][0]+b.p[0][0],p[1][1]+b.p[1][0]);
res[1][1]=max(p[1][0]+b.p[0][1],p[1][1]+b.p[1][1]);
return res;
}
}mat[N],val[M];
void dfs1(int u,int ft){
siz[u]=1,f[u][1]=p[u];
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;if(v==ft)continue;
fa[v]=u,dfs1(v,u),siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
f[u][0]+=max(f[v][0],f[v][1]);
f[u][1]+=f[v][0];
}
}
void dfs2(int u,int tp){
top[u]=tp,dfn[u]=++tmr,rev[tmr]=u,ed[tp]=tmr;
if(son[u])dfs2(son[u],tp);
LL g[2];g[0]=0,g[1]=p[u];
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==son[u]||v==fa[u])continue;
dfs2(v,v);
g[0]+=max(f[v][0],f[v][1]);
g[1]+=f[v][0];
}
mat[u][0][0]=g[0],mat[u][0][1]=g[0];
mat[u][1][0]=g[1],mat[u][1][1]=-inf;
}
#define lc (p<<1)
#define rc (p<<1|1)
void pushup(int p){val[p]=val[lc]*val[rc];}
void build(int l,int r,int p){
if(l==r)return val[p]=mat[rev[l]],void();
int mid=(l+r)>>1;
build(l,mid,lc),build(mid+1,r,rc);
pushup(p);
}
Matrix query(int ql,int qr,int l=1,int r=n,int p=1){
if(ql<=l&&r<=qr)return val[p];
int mid=(l+r)>>1;
if(qr<=mid)return query(ql,qr,l,mid,lc);
if(mid<ql)return query(ql,qr,mid+1,r,rc);
return query(ql,qr,l,mid,lc)*query(ql,qr,mid+1,r,rc);
}
void change(int pos,int l=1,int r=n,int p=1){
if(l==r)return val[p]=mat[rev[l]],void();
int mid=(l+r)>>1;
if(pos<=mid)change(pos,l,mid,lc);
else change(pos,mid+1,r,rc);
pushup(p);
}
void update(int x,LL v){
mat[x][1][0]+=v,p[x]+=v;
while(x){
Matrix lst=query(dfn[top[x]],ed[top[x]]);
change(dfn[x]);
Matrix now=query(dfn[top[x]],ed[top[x]]);
x=fa[top[x]];
mat[x][0][0]+=max(now[0][0],now[1][0])-max(lst[0][0],lst[1][0]);
mat[x][0][1]=mat[x][0][0];
mat[x][1][0]+=now[0][0]-lst[0][0];
}
}
signed main(){
n=read(),m=read(),scanf("%s",cynAKIOI);
for(int i=1;i<=n;++i)p[0]+=(p[i]=read());
for(int i=1;i<n;++i){
int x=read(),y=read();
addedge(x,y),addedge(y,x);
}
dfs1(1,0),dfs2(1,1),build(1,n,1);
while(m--){
int a=read(),x=read(),b=read(),y=read();
LL ad1=x?-inf:inf,out1=x?0:inf;
LL ad2=y?-inf:inf,out2=y?0:inf;
update(a,ad1),update(b,ad2);
Matrix res=query(dfn[1],ed[1]);
LL out=p[0]-max(res[0][0],res[1][0])+out1+out2;
out<inf?printf("%lld\n",out):puts("-1");
update(a,-ad1),update(b,-ad2);
}
return 0;
}
例题四
先考虑不带修改的情况如何dp。
设 \(dp(u,msk)\) 表示以 \(u\) 为根的联通子树 \(\operatorname{xor}\) 起来为 \(msk\) 的方案数。
转移是:
暴力转移是 \(O(m^2)\),非常明显可以 FWT 优化成 \(O(m\log m)\)。
统计答案的时候,假设询问 \(k\),那么就是 \(\sum_{i=1}^{n} dp(i,k)\)。
我觉得这里还是有必要提一下暴力 dp 的边界以及转移的细节。
我一开始写的边界处理是:\(dp(u,w_u)=1\),在把所有孩子合并上来之后再给 \(dp(u,0)\) 加一,这样它父亲调用它的时候那个 \(0\) 就相当于不选自己。
凭直觉就知道在这种鬼地方多个类似 if
的东西可能非常难办。以及 FWT 和 IFWT 的位置可能影响我们维护修改的难度。还有我们统计答案的方式是遍历所有节点而非在单一节点统计答案。这些问题在一开始就得解决。
以下记 \(\hat{a}\) 表示 \(a\) FWT 之后的数组。
首先解决统计答案的问题。
考虑记 \(g(u,msk)=f(u,msk)+\sum_{v\in_{son(u)}}f(v,msk)\)。
这样子我们调用 \(g(1,msk)\) 就能得到整颗树的答案了。
接下去看怎么把转移写简洁。
最后单独给 \(0\) 加一肯定要去掉。那么在转移方程后加一项就行了
FWT 之后有
于是这个转移可以写的非常简洁:
\(\hat{w}(u)\) 表示这个点的点权 FWT 之后的序列。
考虑 \(g\) 怎么搞。如果 IFWT 回去再统计又会使转移非常麻烦。
注意到点值是可以直接加的,那不妨维护 \(\hat{g}\),最后 IFWT 回去输出答案。
现在转移方程非常简洁了,只不过复杂度是 \(O(qnm\log m)\),考虑怎么优化。
注意以下的 \(f,g\) 全部定义为多项式,乘法定义为按位相乘。
考虑 ddp。分离轻重儿子:
记
那么 dp 就写成了下面的形式
然后就构造矩阵转移
修改的时候只需要跳重链修改 \(\hat{F},\hat{G}\) 就好,就是消去原贡献加入现在的贡献。
但是 \(F\) 消贡献是除掉一个东西,并且 XOR 的 FWT 是可以出负数的,加上模数非常小,很有可能除一个 \(0\) 下去(样例就是),看起来非常棘手。
事实上这个处理非常简单:对于每个节点开桶存乘了几个 \(0\),除以 \(0\) 的时候操作桶就行了。
复杂度是 \(O(qm(\log n+\log m))\)。
但是矩阵乘法带 \(27\) 常数,全局平衡二叉树带 \(2\) 倍常数,带进去一算是惊人的 1e9,加上大量封装,根本过不去。
这时候有个小 trick,有些矩阵矩乘之后常数不变,这个矩阵也是这样。
于是只用维护四个值,常数就从 \(27\) 降到了 \(8\)!
到此为止思路结束了,码代码就靠自己了(逃
不过这题别写树剖,多个 \(\log\) 运算量差不多是 1e9,加上洛谷有个毒瘤加了组对着树剖卡的数据,基本不用想过。
Code
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define mkp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define sz(v) (int)v.size()
typedef long long LL;
typedef double db;
template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
#define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
#define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
inline int read(){
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
return f?x:-x;
}
inline int rdch() {
char ch = getchar();
while(ch != 'Q' && ch != 'C') ch = getchar();
return ch == 'Q';
}
const int N = 30005;
const int mod = 10007;
const int iv2 = (mod + 1) >> 1;
int inv[N];
int n, m, w[N];
inline int qpow(int n, int k) {
int res = 1;
for(; k; k >>= 1, n = n * n % mod)
if(k & 1) res = res * n % mod;
return res;
}
inline int add(int x, int y) { return (x += y) >= mod ? x - mod : x; }
inline int sub(int x, int y) { return (x -= y) < 0 ? x + mod : x; }
struct pint {
int v, c;
pint() { v = 1, c = 1; }
pint(int v_) {
if(!v_) v = 1, c = 1;
else v = v_, c = 0;
}
inline int val() const { return c ? 0 : v; }
friend pint operator * (pint a, const int &b) {
if(!b) return ++a.c, a;
else return (a.v *= b) %= mod, a;
}
friend pint operator / (pint a, const int &b) {
if(!b) return --a.c, a;
else return (a.v *= inv[b]) %= mod, a;
}
};
inline vector<int> change(const vector<pint> &a) {
vector<int> res(m);
for(int i = 0; i < m; ++i) res[i] = a[i].val();
return res;
}
inline vector<pint> operator * (const vector<pint> &a, const vector<int> &b) {
vector<pint> res(m);
for(int i = 0; i < m; ++i) res[i] = a[i] * b[i];
return res;
}
inline vector<pint> operator / (const vector<pint> &a, const vector<int> &b) {
vector<pint> res(m);
for(int i = 0; i < m; ++i) res[i] = a[i] / b[i];
return res;
}
inline vector<int> operator + (const vector<int> &a, const vector<int> &b) {
vector<int> res(m);
for(int i = 0; i < m; ++i) res[i] = add(a[i], b[i]);
return res;
}
inline vector<int> operator - (const vector<int> &a, const vector<int> &b) {
vector<int> res(m);
for(int i = 0; i < m; ++i) res[i] = sub(a[i], b[i]);
return res;
}
inline vector<int> operator * (const vector<int> &a, const vector<int> &b) {
vector<int> res(m);
for(int i = 0; i < m; ++i) res[i] = a[i] * b[i] % mod;
return res;
}
inline vector<int> addone(vector<int> a) {
for(int i = 0; i < m; ++i) a[i] = add(a[i], 1);
return a;
}
inline vector<int> XOR(vector<int> a) {
for(int i = 1; i < m; i <<= 1)
for(int j = 0; j < m; j += i << 1)
for(int k = 0; k < i; ++k) {
int X = a[j + k], Y = a[i + j + k];
a[j + k] = add(X, Y), a[i + j + k] = sub(X, Y);
}
return a;
}
inline vector<int> IXOR(vector<int> a) {
for(int i = 1; i < m; i <<= 1)
for(int j = 0; j < m; j += i << 1)
for(int k = 0; k < i; ++k) {
int X = a[j + k], Y = a[i + j + k];
a[j + k] = (X + Y) * iv2 % mod, a[i + j + k] = (X - Y + mod) * iv2 % mod;
}
return a;
}
int rt, tfa[N], fa[N], cnz[N], siz[N], son[N], stk[N], top, ch[N][2], tsz[N];
bool isrt[N];
vector<int> e[N];
struct Matrix {
vector<int> a00, a10, a02, a12;
inline Matrix operator * (const Matrix &t) const {
Matrix res;
res.a00 = a00 * t.a00;
res.a10 = a10 * t.a00 + t.a10;
res.a02 = a00 * t.a02 + a02;
res.a12 = a10 * t.a02 + a12 + t.a12;
return res;
}
} val[N], sum[N];
Matrix mat;
vector<pint> F[N];
vector<int> G[N], ans, f[N], g[N], a[N];
inline void pushup(int u) {
sum[u] = val[u];
if(ch[u][0]) sum[u] = sum[ch[u][0]] * sum[u];
if(ch[u][1]) sum[u] = sum[u] * sum[ch[u][1]];
}
inline void get(int u) {
val[u].a00 = val[u].a10 = val[u].a02 = val[u].a12 = change(F[u]);
val[u].a12 = val[u].a12 + G[u];
}
void dfs(int u, int ft) {
f[u].resize(m), g[u].resize(m);
f[u][w[u]] = 1, f[u] = XOR(f[u]);
a[u] = f[u];
siz[u] = 1;
for(int v : e[u]) if(v != ft) {
tfa[v] = u, dfs(v, u), siz[u] += siz[v];
if(siz[v] > siz[son[u]]) son[u] = v;
f[u] = f[u] * addone(f[v]), g[u] = g[u] + g[v];
}
g[u] = g[u] + f[u];
F[u].resize(m), G[u].resize(m);
for(int i = 0; i < m; ++i) F[u][i] = a[u][i];
for(int v : e[u]) if(v != ft && v != son[u]) {
F[u] = F[u] * addone(f[v]), G[u] = G[u] + g[v];
}
get(u);
}
inline int build2(int l, int r) {
if(l > r) return 0;
int ALL = 0, now = 0;
for(int i = l; i <= r; ++i) ALL += tsz[i];
for(int i = l; i <= r; ++i) {
now += tsz[i];
if(now << 1 >= ALL) {
int u = stk[i];
fa[ch[u][0] = build2(l, i - 1)] = u;
fa[ch[u][1] = build2(i + 1, r)] = u;
return pushup(u), u;
}
}
return -1;
}
int build(int tp) {
for(int i = tp; i; i = son[i])
for(int v : e[i]) if(v != son[i] && v != tfa[i])
fa[build(v)] = i;
top = 0;
for(int i = tp; i; i = son[i]) stk[++top] = i, tsz[top] = siz[i] - siz[son[i]];
int tmp = build2(1, top);
return isrt[tmp] = 1, tmp;
}
void modify(int x, int y) {
F[x] = F[x] / a[x];
memset(a[x].data(), 0, m << 2);
a[x][y] = 1, w[x] = y, a[x] = XOR(a[x]);
F[x] = F[x] * a[x], get(x);
for(; x; x = fa[x]) {
if(fa[x] && isrt[x]) {
F[fa[x]] = F[fa[x]] / addone(sum[x].a02), G[fa[x]] = G[fa[x]] - sum[x].a12;
pushup(x);
F[fa[x]] = F[fa[x]] * addone(sum[x].a02), G[fa[x]] = G[fa[x]] + sum[x].a12;
get(fa[x]);
} else pushup(x);
}
}
signed main() {
inv[1] = 1;
for(int i = 2; i < mod; ++i) inv[i] = inv[mod % i] * (mod - mod / i) % mod;
n = read(), m = read();
rep(i, 1, n) w[i] = read();
rep(i, 2, n) {
int x = read(), y = read();
e[x].pb(y), e[y].pb(x);
}
dfs(1, 0);
rt = build(1);
ans = IXOR(sum[rt].a12);
for(int q = read(); q; --q) {
int op = rdch(), x = read();
if(op == 1) {
printf("%d\n", ans[x]);
} else {
int y = read();
modify(x, y);
ans = IXOR(sum[rt].a12);
}
}
return 0;
}