高级算法指北——浅谈整体dp

简述整体dp

整体 dp 是一种优化 dp 的方式。它通过类似数据结构维护序列的方式将 dp 状态中的一维压入数据结构,并通过少量单点或区间的操作达到对该维所有状态进行转移的目的,从而将一维 \(O(n)\) 的状态带来的至少 \(O(n)\) 的时间复杂度减少至 \(O(\log n)\) 的级别。

不难发现,整体 dp 是基于 \(O(1)\) 转移的 dp 的再优化。当 dp 状态数过多,而限制了程序运行效率时,若某一维(通常是后一维)的转移具有较强的共性时,可以考虑利用整体 dp 优化。

一般地,使用整体 dp 的题目有以下几个步骤:

  1. 写出朴素的 dp

  2. 将朴素 dp 使用前缀和等方式优化至 \(O(1)\) 转移(如果本来是 \(O(1)\) 的一对一转移就不需要了)

  3. 找出 dp 状态中转移具有共性的一维,使用数据结构维护这一维。具体地,随着其它维的变化,在数据结构上执行各种修改操作,动态维护此时此刻,当压进数据结构一维的下标为某个值时的 dp 值。

下面用若干例题说明整体 dp 的过程。

经典例题

谈到数据结构,首先想到的应该是线段树。下面先看几道用线段树维护整体 dp 的例子。

P01 整体 dp 板子题:Luogu P9400 「DBOI」Round 1 三班不一般

简要题意:构造一个长度为 \(n\) 的序列,序列上第 \(i\) 个元素的值域范围为 \([l_i,r_i]\),求有多少个序列使得不存在连续的 \(a\) 个元素的值大于 \(b\)

解法:
考虑 dp,设 \(dp_{i,j}\) 表示考虑到第 \(i\) 个位置,尾端有连续 \(j\) 个值大于 \(b\) 的方案数。设 \(val\) 为对应情况下 \(i\)
处可以填的数的数量,可以简单求出。则有:

\[dp_{i,0}=\sum_{j=0}^{a-1} dp_{i-1,j}\times val_1 \]

\[dp_{i,j}=dp_{i-1,j-1}\times val_2 \ (j>0) \]

对于 \(\sum_{j=0}^{a-1} dp_{i-1,j}\) 的部分可以提前求和做到 \(O(1)\) 转移。

发现已经无法继续优化转移,且每个 \(j\) 都是从 \(j-1\) 处转移过来的,于是考虑整体 dp。不难发现我们需要一个支持 单点插入、单点删除、区间位移、区间乘、区间求和 的数据结构。平衡树显然可以做,这里提供一个线段树的做法:不难发现序列长度始终为 \(a\),位移最多 \(n\) 次,因此我们用线段树维护一个长度为 \(n+a\) 的序列,并维护当前有效的长度为 \(a\) 的区间,每次先区间求和算出新的 \(dp_{i,0}\),然后将区间向左移动一位,将左侧新增的那个位置的值设为 \(dp_{i,0}\),最后给剩余部分区间乘即可。

下面是 AC 代码。可以结合代码细节理解。

int n,a,b;
struct seg{
	int t[4*N],tag[4*N];
	void pushdown(int x){
		if(tag[x]==-1)return; 
		if(tag[ls(x)]==-1)tag[ls(x)]=1;
		if(tag[rs(x)]==-1)tag[rs(x)]=1;
		t[ls(x)]*=tag[x],t[rs(x)]*=tag[x];tag[ls(x)]*=tag[x],tag[rs(x)]*=tag[x];//利用乘法结合律与分配律 
		t[ls(x)]%=mo,t[rs(x)]%=mo,tag[ls(x)]%=mo,tag[rs(x)]%=mo,tag[x]=-1;
	}
	void pushup(int x){
		t[x]=(t[ls(x)]+t[rs(x)])%mo;
	}
	void modify(int x,int le,int ri,int p,int v){
		if(le==ri){
			t[x]=v,tag[x]=-1;
			return;
	    }
	    pushdown(x);
	    int mid=(le+ri)>>1;
	    if(p<=mid)modify(ls(x),le,mid,p,v);
		else modify(rs(x),mid+1,ri,p,v);
		pushup(x);
	}
	void mult(int x,int le,int ri,int ql,int qr,int v){
		if(ql<=le&&qr>=ri){
			if(tag[x]==-1)tag[x]=1;
			t[x]=t[x]*v%mo,tag[x]=tag[x]*v%mo;
			return;
		}
		pushdown(x);
		int mid=(le+ri)>>1;
		if(ql<=mid)mult(ls(x),le,mid,ql,qr,v);
		if(qr>mid)mult(rs(x),mid+1,ri,ql,qr,v);
		pushup(x);
	}
	int query(int x,int le,int ri,int ql,int qr){
		if(ql<=le&&qr>=ri)return t[x];
		pushdown(x);
		int mid=(le+ri)>>1,res=0;
		if(ql<=mid)res+=query(ls(x),le,mid,ql,qr);
		if(qr>mid)res+=query(rs(x),mid+1,ri,ql,qr);
		return res%mo;
	}
}T;
signed main(){
	memset(T.tag,-1,sizeof(T.tag));
	read(n),read(a),read(b);
	int le=n+1,ri=n+a;
	T.modify(1,1,n+a,n+1,1);
	rep(i,1,n){
		int x,y;
		read(x),read(y);
		int sum=T.query(1,1,n+a,le,ri);
		le--,ri--;//区间移动
		int val=max(0ll,y-max(b,x-1));
		if(le+1<=ri)T.mult(1,1,n+a,le+1,ri,val);//更新 dp[i][1]~dp[i][a-1]
		val=max(0ll,min(b,y)-x+1)*sum%mo;
		T.modify(1,1,n+a,le,val);//更新 dp[i][0]
    }
    printf("%lld\n",T.query(1,1,n+a,le,ri));
    return 0;
}

P02 线段树的复杂操作:Luogu P8476 「GLR-R3」惊蛰

简要题意:给定长度为 \(n\) 的序列 \(a\),你需要修改其每个位置的元素值得到序列 \(b\)。对于每个 \(1\leq i\leq n\),若 \(b_i\geq a_i\),则你需要花费 \(b_i-a_i\) 的代价;若 \(b_i<a_i\),则你需要花费 \(c\) 的代价。问使得 \(b\) 不增的最小代价。

解法:

考虑 dp。设 \(dp_{i,j}\) 表示调整完前 \(i\) 个数,最后一个数的值为 \(j\) 的时候的最小花费。不难发现调整到 \(a\) 序列中的已有值一定不劣,所以可以离散化处理。于是最朴素的 dp 转移可以设计出来:\(dp_{i,j}=\min_{k=j}^{n} dp_{i-1,k}+val\),其中 \(val\) 表示将 \(a_i\) 改成 \(j\) 的代价。我们给 \(dp_{i-1}\) 滚一个后缀 min 之后即可做到 \(O(1)\) 转移。

接下来考虑如何把第二维搬上数据结构。首先滚完后缀 min 之后的 dp 数组是自己向自己位置转移,没有位移操作,可以直接用线段树简单维护。不难发现 \(val\) 是个分段函数,于是找到分界点,即 \(a_i=j\) 的位置。对于这个位置之前的所有下标,本次操作的代价均为 \(c\),直接区间加 \(c\) 即可;对于这个位置及之后的所有下标 \(x\),每个位置需要加上 \(b_x-a_i\)。首先 \(-a_i\) 的部分可以直接区间加,而 \(b_x\) 是单调递增的,滚了后缀 min 之后的 dp 数组也是单增的,因此修改后的最小值一定取在区间的最左端点处。于是加 \(b_x\) 可以直接打 tag,对于每个有 tag 的区间,给最小值加上区间左端点下标对应的值即可。最后还需要滚一个后缀 min。不难发现,修改后 \(j<a_i\)\(j\geq a_i\) 的位置分别单调递增,于是直接找到 \(j\geq a_i\) 部分的最小值,并在 \(j<a_i\) 二分找到大于右侧最小值的部分,区间赋值抹平即可。于是最终复杂度 \(O(n\log n)\)

int n,c,a[N],lsh[N],cntl;
struct seg{
	int t[4*N],tx[4*N],tag[4*N],cntt[4*N],tagv[4*N];
	void pushup(int x){
		t[x]=min(t[ls(x)],t[rs(x)]);
		tx[x]=max(tx[ls(x)],tx[rs(x)]);
	}
	void pushdown(int x,int le,int ri){
		if(tagv[x]!=-1){
			t[ls(x)]=t[rs(x)]=tx[ls(x)]=tx[rs(x)]=tagv[ls(x)]=tagv[rs(x)]=tagv[x],tagv[x]=-1; 
			cntt[ls(x)]=cntt[rs(x)]=tag[ls(x)]=tag[rs(x)]=0;//pushdown的时候,赋值标记依然会影响其他标记 
		}
		int mid=(le+ri)>>1;
		if(cntt[x]){
		    cntt[ls(x)]+=cntt[x],cntt[rs(x)]+=cntt[x],t[ls(x)]+=cntt[x]*lsh[le],t[rs(x)]+=cntt[x]*lsh[mid+1];
			tx[ls(x)]+=cntt[x]*lsh[mid],tx[rs(x)]+=cntt[x]*lsh[ri],cntt[x]=0;
		}
		if(tag[x])tag[ls(x)]+=tag[x],tag[rs(x)]+=tag[x],t[ls(x)]+=tag[x],t[rs(x)]+=tag[x],tx[ls(x)]+=tag[x],tx[rs(x)]+=tag[x],tag[x]=0;
	}
	void add(int x,int le,int ri,int ql,int qr,int v){//区间加
		if(ql<=le&&qr>=ri){
			t[x]+=v,tx[x]+=v,tag[x]+=v;
			return;
	    }
	    pushdown(x,le,ri);
	    int mid=(le+ri)>>1;
	    if(ql<=mid)add(ls(x),le,mid,ql,qr,v);
	    if(qr>mid)add(rs(x),mid+1,ri,ql,qr,v);
	    pushup(x); 
	}
	void addi(int x,int le,int ri,int ql,int qr){//区间加下标
		if(ql<=le&&qr>=ri){
			t[x]+=lsh[le],tx[x]+=lsh[ri],cntt[x]++;
			return;
		}
		pushdown(x,le,ri);
		int mid=(le+ri)>>1;
		if(ql<=mid)addi(ls(x),le,mid,ql,qr);
		if(qr>mid)addi(rs(x),mid+1,ri,ql,qr);
		pushup(x); 
	}
	void modify(int x,int le,int ri,int ql,int qr,int v){//区间赋值
		if(ql<=le&&qr>=ri){
			tagv[x]=t[x]=tx[x]=v,cntt[x]=tag[x]=0;
			return;
		}
		pushdown(x,le,ri);
		int mid=(le+ri)>>1;
		if(ql<=mid)modify(ls(x),le,mid,ql,qr,v);
	    if(qr>mid)modify(rs(x),mid+1,ri,ql,qr,v);
	    pushup(x);
	}
	int query(int x,int le,int ri,int ql,int qr){//区间最小值
		if(ql<=le&&qr>=ri)return t[x];
		pushdown(x,le,ri);
		int mid=(le+ri)>>1,res=inf;
		if(ql<=mid)res=min(res,query(ls(x),le,mid,ql,qr));
		if(qr>mid)res=min(res,query(rs(x),mid+1,ri,ql,qr));
		return res;
	}
	int querypos(int x,int le,int ri,int v){//线段树上二分
		if(tx[x]<v)return ri+1;
		if(le==ri)return le;
		pushdown(x,le,ri);
		int mid=(le+ri)>>1;
		if(tx[ls(x)]>=v)return querypos(ls(x),le,mid,v);
		else return querypos(rs(x),mid+1,ri,v);
	}
}T;
signed main(){
	read(n),read(c);
	rep(i,1,n)
	    read(a[i]),lsh[i]=a[i];
	sort(lsh+1,lsh+n+1),cntl=unique(lsh+1,lsh+n+1)-lsh-1;
	rep(i,1,n)
	    a[i]=lower_bound(lsh+1,lsh+cntl+1,a[i])-lsh;
	memset(T.tagv,-1,sizeof(T.tagv));//注意赋值有可能赋成0,所以tagv初始值为-1. 
	rep(i,1,n){
		if(a[i]!=1)T.add(1,1,cntl,1,a[i]-1,c);
		T.add(1,1,cntl,a[i],cntl,-lsh[a[i]]),T.addi(1,1,cntl,a[i],cntl);
		int rmin=T.query(1,1,cntl,a[i],cntl),targ=T.querypos(1,1,cntl,rmin);
		if(targ<a[i])T.modify(1,1,cntl,targ,a[i]-1,rmin);
	}
	printf("%lld\n",T.t[1]);
	return 0; 
}

P03 用线段树合并支持树上整体dp:Luogu P6773 [NOI2020] 命运

简要题意:给定一棵 \(n\) 个点的树和 \(m\) 条限制,给树上的每条边赋一个 \(0\)\(1\) 的权值,对于每个限制 \((u_i,v_i)\) (满足 \(u_i\)\(v_i\) 的祖先),你需要保证 \(u_i\)\(v_i\) 的路径上至少有一条边的权值为 \(1\)。求赋值方案数。

解法:

不难发现对于从同一个点出发的若干个限制,若 \(u\) 最深的一个限制被满足,那么所有的限制均会被满足。因而我们只关心最深的那个限制。因此可以设 \(dp_{i,j}\) 表示在 \(i\) 的子树内,未被满足的限制向上延伸到最深深度为 \(j\) 的方案数,并设全部满足时 \(j=0\)。讨论 \(i\) 与其某个儿子 \(u\) 的边权情况,若为 \(1\),则到 \(j\) 的限制必须由 \(i\) 贡献,这部分的答案为 \(dp_{i,j}\times \sum_{k=0}^{dep_i}dp_{u,k}\);若为 \(0\),则深度最大值由 \(i\)\(u\) 中的任意一个贡献,注意两者都贡献的情况不要算重,这部分的答案为 \(dp_{i,j}\times \sum_{k=0}^{j}dp_{u,k}+dp_{u,j}\times \sum_{k=0}^{j-1}dp_{i,k}\)。滚一个前缀和即可做到 \(O(1)\) 转移。

考虑用线段树维护第二维,合并时直接线段树合并。对于第一部分的 \(\sum_{k=0}^{dep_i}dp_{u,k}\) 可以在线段树上区间求和得到,在合并两个节点之前先乘上去即可;第二部分的 \(\sum_{k=0}^{j}dp_{u,k}\)\(\sum_{k=0}^{j-1}dp_{i,k}\) 均与下标有关,考虑在合并线段树时先合并左子树,再合并右子树,合并过程中动态维护两个和值。具体地,合并到某个位置,就给两个和值加上对应节点上的值即可。注意由于 \(j-1\) 这个上界的存在,两个和值的加答案与合并操作的先后顺序是不同的。

int n,m;
struct edge{
	int to,nxt;
}e[2*N];
int fir[N],np,dep[N];
vector<int>op[N];
void add(int x,int y){
	e[++np]=(edge){y,fir[x]};
	fir[x]=np;
}
int sl,sr;
struct seg{//注意在线合的过程中,若一边有一边没有,会访问不到叶子节点,因而我们需要一个乘法tag. 
	int t[30*N],lson[30*N],rson[30*N],rt[N],cnt=0,tag[30*N];
	void pushup(int x){
		t[x]=(t[ls(x)]+t[rs(x)])%mo;
	}
	void pushdown(int x){
		if(tag[x]==1)return;
		tag[ls(x)]=tag[ls(x)]*tag[x]%mo,tag[rs(x)]=tag[rs(x)]*tag[x]%mo;
		t[ls(x)]=t[ls(x)]*tag[x]%mo,t[rs(x)]=t[rs(x)]*tag[x]%mo;
		tag[x]=1;
	}
	void modify(int &x,int le,int ri,int p,int v){
		if(!x)x=++cnt,tag[x]=1;
		if(le==ri){
			t[x]=v,tag[x]=1;
			return;
		}
		pushdown(x);
		int mid=(le+ri)>>1;
		if(p<=mid)modify(ls(x),le,mid,p,v);
		else modify(rs(x),mid+1,ri,p,v);
		pushup(x);
	}
	int query(int x,int le,int ri,int ql,int qr){
		if(!x)return 0;
		if(ql<=le&&qr>=ri)return t[x];
		pushdown(x);
		int mid=(le+ri)>>1,ret=0;
		if(ql<=mid)ret+=query(ls(x),le,mid,ql,qr);
		if(qr>mid)ret+=query(rs(x),mid+1,ri,ql,qr);
		return ret%mo;
    }
    int merge(int p,int q,int le,int ri){
    	if(!p&&!q)return 0;
    	if(!p){
    		sl=(sl+t[q])%mo,t[q]=t[q]*sr%mo,tag[q]=tag[q]*sr%mo;
    		return q;
    	}
    	if(!q){
    		sr=(sr+t[p])%mo,t[p]=t[p]*sl%mo,tag[p]=tag[p]*sl%mo;
    		return p;
    	}
    	if(le==ri){
    		sl=(sl+t[q])%mo,t[q]=t[q]*sr%mo,sr=(sr+t[p])%mo,t[p]=t[p]*sl%mo;
    		t[p]=(t[p]+t[q])%mo;
    		return p;
    	}
    	pushdown(p),pushdown(q);
    	int mid=(le+ri)>>1;
    	ls(p)=merge(ls(p),ls(q),le,mid),rs(p)=merge(rs(p),rs(q),mid+1,ri);
		pushup(p);
		return p; 
    }
}T;
void dfs(int x,int f){
	dep[x]=dep[f]+1;
	int maxd=0;
	rep(i,0,(int)op[x].size()-1)
	    maxd=max(maxd,dep[op[x][i]]);
	T.modify(T.rt[x],0,n,maxd,1);
	for(int i=fir[x];i;i=e[i].nxt){
		int j=e[i].to;
		if(j==f)continue;
		dfs(j,x);
		sl=T.query(T.rt[j],0,n,0,dep[x]),sr=0;
		T.rt[x]=T.merge(T.rt[x],T.rt[j],0,n);
	}
}
signed main(){
//	freopen("destiny.in","r",stdin);
//	freopen("destiny.out","w",stdout); 
	read(n);
	rep(i,1,n-1){
		int x,y;
		read(x),read(y),add(x,y),add(y,x);
	}
	read(m);
	rep(i,1,m){
		int x,y;
		read(x),read(y),op[y].push_back(x);
	}
	dfs(1,0);
	printf("%lld\n",T.query(T.rt[1],0,n,0,0));
	return 0;
}

其他数据结构,如平衡树、可并堆等,在整体 dp 过程中也有独特的优势,如平衡树支持插入和删除,可并堆空间占用小......

P04 用平衡树维护复杂的插入、删除、位移操作:CF809D Hitchhiking in the Baltic States

简要题意:构造一个长度为 \(n\) 的序列,序列上第 \(i\) 个元素的值域范围为 \([l_i,r_i]\),最大化这个序列严格上升子序列的长度。求这个长度的最大值。

解法:

模仿 LIS 的 dp 状态,设 \(dp_{i,j}\) 表示考虑到第 \(i\) 个位置,长度为 \(j\) LIS 的最小结尾数字。注意到 \(dp_i\) 这个序列是单调递增的。考虑下面几种转移:

  • \(dp_{i-1,j}<l_i\) 时,\(dp_{i,j+1}=\min(dp_{i,j+1},l)\),找到最后一个满足 \(dp{i-1,x}<l_i\) 的位置 \(x\),不难发现仅有 \(dp_{i,x+1}=l_i\),其余部分都是 \(dp_{i,j}=dp_{i-1,j}\)

  • \(l_i\leq dp_{i-1,j} < r_i\) 时,\(dp_{i,j+1}=\min(dp_{i,j+1}.dp{i-1,j})\)。由于 \(dp_i\) 单增,因而更新一定不劣。故有 \(dp_{i,j+1}=dp_{i-1,j}+1\)

  • \(dp_{i-1,j}\geq ri_i\) 时,不能更新,有 \(dp_{i,j}=dp{i-1,j}\)

观察上面的转移,不难发现我们实际上是维持了 \(0\sim x\) 的下标对应的值不动,在第一部分右侧插入值 \(l_i\);满足第二部分条件的所有下标上对应的值先加 \(1\),然后下标往大平移 \(1\) 位(其实直接在第一部分最右侧插入一个值之后,第二部分平移的目的就达到了,只需要加 \(1\))。最后第二部分平移操作和第三部分不动操作会使得重复 \(1\) 位,需要将第三部分最左侧的那个值删掉。这种复杂的插入、删除、修改操作可以用平衡树解决。

下面给出用 FHQ 实现的代码。

int n;
mt19937 rd(time(NULL)); 
struct FHQ{
	int hp[N],val[N],lson[N],rson[N],sz[N],tag[N],cnt=0,rt=0;
	int addnode(int x){
		val[++cnt]=x,sz[cnt]=1,hp[cnt]=rd();
		return cnt;
	}
	void pushup(int x){
		sz[x]=sz[ls(x)]+sz[rs(x)]+1;
	}
	void pushdown(int x){
		if(ls(x))tag[ls(x)]+=tag[x],val[ls(x)]+=tag[x];
		if(rs(x))tag[rs(x)]+=tag[x],val[rs(x)]+=tag[x];
	    tag[x]=0;
	}
	int merge(int x,int y){//x<y
		if(!x)return y;
		if(!y)return x;
		if(hp[x]>hp[y]){
			pushdown(x),rs(x)=merge(rs(x),y),pushup(x);
			return x;
		}
		else{
			pushdown(y),ls(y)=merge(x,ls(y)),pushup(y);
			return y;
		}
	}
	void split(int nw,int v,int &x,int &y){//按值分裂:小于等于v的在x 
		if(!nw){
			x=y=0;
			return;
		}
		pushdown(nw);
		if(val[nw]<=v)x=nw,split(rs(nw),v,rs(x),y),pushup(x);
		else y=nw,split(ls(nw),v,x,ls(y)),pushup(y);
	}
	void insert(int v){//插入一个值为v的节点 
		int x,y;
		split(rt,v,x,y),rt=merge(merge(x,addnode(v)),y);
	}
	void modify(int l,int r){//区间修改 
		int x,y,z;
		split(rt,l-1,x,y),split(y,r,y,z),tag[y]++,val[y]++,rt=merge(merge(x,y),z);
	}
	int getk(int nw,int k){//找第k小的值 
	    if(sz[nw]<k)return inf;
	    pushdown(nw);
		if(sz[ls(nw)]>=k)return getk(ls(nw),k);
		else if(sz[ls(nw)]==k-1)return val[nw];
		else return getk(rs(nw),k-sz[ls(nw)]-1); 
	}
}T;
int main(){
	read(n);
	while(n--){
		//先删,再平移,最后插入,是一个腾位置的过程,免得平移过去的被删了,插入的被一起平移了. 
		int l,r;
		read(l),read(r);
	    int x,y;
		T.split(T.rt,r-1,x,y);
		if(y){
			int val=T.getk(y,1),z;
			if(val!=inf){
				T.split(y,val,y,z);//删除1个节点,就删除这个根,将左右儿子合并即可. 
	    		y=T.merge(T.lson[y],T.rson[y]),y=T.merge(y,z);
		    }
		}
		T.rt=T.merge(x,y);
		T.modify(l,r-1);
		T.insert(l);
	}
	printf("%d\n",T.sz[T.rt]);
	return 0;
}

P05 用堆维护取最值转移的操作:CF671D Roads in Yusland

简要题意:给定一棵有 \(n\) 个节点、以 \(1\) 为根的有根树,并给定 \(m\) 条形如 \((u,v)\) 的路径,保证 \(v\)\(u\) 本身或其祖先,且每条路径有权值 \(w_i\)。请选择若干条路径,使得它们覆盖树上所有边的同时权值和最小。求出这个最小值。

解法:

套路化地考虑提前钦定,设 \(dp_{i,j}\) 表示 \(i\) 的子树全部被覆盖,且已选择的路径从 \(i\) 往上延伸到了深度 \(j\) 的位置。考虑合并 \(i\)\(i\) 的某个子节点 \(u\) 时,答案可能由 \(dp_i\)\(dp_u\) 中的某一个产生。具体地,\(dp_{i,j}=\min(dp_{i,j}+\min_{k=0}^{dep_i}dp_{u,k},dp_{u,j}+\min_{k=0}^{dep_i}dp_{i,k})\)。这里显然可以用线段树合并的思路维护整体 dp,提前求出两个最小值,合并时先加上去,在取 \(\min\) 即可。这种做法在实现较好的情况下可以通过。

但我们可以考虑空间花费更小的做法。不难发现取 \(\min\) 的操作可以用堆来维护,对于不满足 \(j\leq dep_i\) 的状态懒惰删除(即堆顶出现了不合法的值再删除)解决即可。因此我们在每个节点上维护一个堆,堆里装所有的第二维状态和值即可。这里可以采用左偏树做到 \(n\log n\),也可以堆的启发式合并做到 \(n\log^2 n\)。下面贴出启发式合并的代码,加法 tag 是打在整个堆上的。

int n,m,dep[N];
struct edge{
	int to,nxt;
}e[2*N];
int fir[N],np;
struct work{
	int to,val;
};
vector<work>w[N];
int main();
struct node{
	int upper;
	ll co;
	friend bool operator<(node x,node y){
		return x.co>y.co;
	}
};
priority_queue<node>q[N];
void add(int x,int y){
	e[++np]=(edge){y,fir[x]};
	fir[x]=np;
}
ll ans=0,tag[N];
bool ok=1;
void dfs(int x,int f){
	dep[x]=dep[f]+1;
	for(int i=fir[x];i;i=e[i].nxt){
		int j=e[i].to;
		if(j==f)continue;
		dfs(j,x);
		if(!ok)return;
		if(q[x].size()<q[j].size())swap(tag[x],tag[j]),swap(q[x],q[j]);
		ll del=0,minx=0;
		if(!q[j].empty())del=q[j].top().co;
		if(!q[x].empty())minx=q[x].top().co;
		while(!q[j].empty()){
			node nw=q[j].top();
			q[j].pop();
			nw.co+=minx-del;
			q[x].push(nw);
		}
		tag[x]+=del+tag[j];
	}
	ll minn=0;
	if(!q[x].empty())minn=q[x].top().co;
	rep(i,0,(int)w[x].size()-1)
		q[x].push((node){w[x][i].to,(ll)w[x][i].val+minn});
	if(q[x].empty()){
		ok=0;
		return;
	}
	if(x==1)ans=q[x].top().co+tag[x];
	else{
		while(!q[x].empty()&&dep[q[x].top().upper]>=dep[x])
		    q[x].pop();
		if(q[x].empty())ok=0;
	}
}
int main(){
	read(n),read(m);
	rep(i,1,n-1){
		int x,y;
		read(x),read(y),add(x,y),add(y,x);
	}
	rep(i,1,m){
		int x,y,v;
		read(x),read(y),read(v);
		w[x].push_back((work){y,v});
	}
	if(n==1){
		printf("0\n");
		return 0;
	}
	dfs(1,0);
	if(!ok)printf("-1");
	else printf("%lld\n",ans);
	return 0;
}

完结撒花~❀

posted @ 2023-11-12 21:51  烟山嘉鸿  阅读(388)  评论(0编辑  收藏  举报