P6773 [NOI2020] 命运

P6773 [NOI2020] 命运

注意到如果有两条限制 \((x,y)\)\((x,z)\)(满足 \(\text{dep}_z < \text{dep}_y \le \text{dep}_x\),则当 \((x,y)\) 满足条件时,\((x,z)\) 也一定满足条件。

\(f_{x,i}\) 表示当前节点 \(x\) 子树内的边已经染色完成,当前子树内向上延伸的最大 \(\text{dep}\)\(i\) 的方案数。考虑 \(x\) 与子节点 \(v\) 的一条边的染色情况。

  • 当这条边染色时,\(v\) 子树内的所有点都染色完成,则 \(i\) 一定时由 \(x\) 的其他子树贡献的,即此时贡献为 \(f_{x,i}\times \sum_{j=0}^{\text{dep}_x}f_{v,j}\)
  • 当这条边不染色时
    • \(i\)\(x\) 其他子树贡献时,\(v\) 上方延伸小于 \(x\),即贡献为 \(f_{x,i}\times \sum_{j=0}^i f_{v,j}\)
    • \(i\)\(v\) 贡献时,\(v\) 的延伸为 \(i\)\(x\) 其他子树上方延伸小于 \(x\),即贡献为 \(f_{v,i}\times \sum_{j=0}^i f_{x,j}\)
    • 由于上面当 \(x\) 的延伸与 \(v\) 的延伸均为 \(i\)\(f_{x,i}\times f_{v,i}\) 计算了两次,所以要减一次。

综上,可以得到最终的转移方程为

\[\begin{aligned} f_{x,i}&=f_{x,i}\times\sum_{j=0}^{\text{dep}_x}f_{v,j}+f_{x,i}\times \sum_{j=0}^i f_{v,j}+f_{v,i}\times \sum_{j=0}^i f_{x,j}-f_{x,i}\times f_{v,i}\\ &=f_{x,i}\times\sum_{j=0}^{\text{dep}_x}f_{v,j}+f_{x,i}\times \sum_{j=0}^i f_{v,j}+f_{v,i}\times \sum_{j=0}^{i-1} f_{x,j}\\ &=f_{x,i}\times(\sum_{j=0}^{\text{dep}_x}f_{v,j}+\sum_{j=0}^i f_{v,j})+f_{v,i}\times \sum_{j=0}^{i-1} f_{x,j} \end{aligned} \]

\(g_{x,i} = \sum_{j=0}^i f_{x,i}\),则有

\[f_{x,i}=f_{x,i}\times(g_{v,\text{dep}_x}+g_{v,i})+f_{v,i}\times g_{x,i-1} \]

至此,直接转移可以得到 \(\mathcal{O}(n^2)\) 的做法。


考虑优化,由于出现前缀和形式,考虑将 \(\text{dp}\) 的第二维放在权值线段树中维护,并使用线段树合并将信息从子节点转移到父节点。

由于转移时转移顺序会影响 \(g\) 的值,注意到仅有前缀和形式,考虑使用左中右转移顺序,先递归左子树,再递归右子树。

假设当前合并 \(x,y\) 两棵线段树,合并的区间为 \([l,r]\)

  • \(x,y\) 均为空时,直接返回即可。
  • \(x\) 为空,\(y\) 非空时,在此区域内,\(f_{x,i}=0\),并且有关 \(x\) 线段树的权值和不再增长,此时转移的右侧仅有 \(f_{v,i}\times g_{x,i-1}\),由于 \(\forall i \in [l,r],f_{x,i}=0\),则 \(g_{x,i-1}=g_{x,l-1}\),即对 \(y\) 线段树原有 \([l,r]\) 区间的每个位置乘上 \(g_{x,l-1}\) 即可。
  • \(x\) 非空,\(y\) 为空时,在此区域内,\(f_{v,i} =0\),并且有关 \(y\) 线段树的权值和不再增长, 此时转移的右侧仅有 \(f_{x,i}\times (g_{v,\text{dep}_x}+g_{v,i})\)。由于 \(\forall i \in [l,r],f_{v,i}=0\),则 \(g_{v,i}=g_{v,l-1}\),即对 \(x\) 线段树原有 \([l,r]\) 区间的每个位置乘上 \((g_{v,\text{dep}_x}+g_{v,i})\) 即可。
  • \(x,y\) 均非空时
    • 若当前是叶子节点,则按照上面的方程直接将 \(i\) 这一位进行转移即可。
    • 若当前不是叶子节点,先递归左区间,后递归右区间,保证进入右区间时左区间的 \(f\) 已经累加进 \(g\)

至此,可以得到时间复杂度与空间复杂度均为 \(\mathcal{O}(n \log n)\) 的做法。

code
#include<bits/stdc++.h>
using namespace std;
namespace IO{
	template<typename T>inline bool read(T &x){
		x=0;
		char ch=getchar();
		bool flag=0,ret=0;
		while(ch<'0'||ch>'9') flag=flag||(ch=='-'),ch=getchar();
		while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar(),ret=1;
		x=flag?-x:x;
        return ret;
	}
	template<typename T,typename ...Args>inline bool read(T& a,Args& ...args){
	    return read(a)&&read(args...);
	}
	template<typename T>void prt(T x){
		if(x>9) prt(x/10);
		putchar(x%10+'0');
	}
	template<typename T>inline void put(T x){
		if(x<0) putchar('-'),x=-x;
		prt(x);
	}
	template<typename T>inline void put(char ch,T x){
		if(x<0) putchar('-'),x=-x;
		prt(x);
		putchar(ch);
	}
	template<typename T,typename ...Args>inline void put(T a,Args ...args){
	    put(a);
		put(args...);
	}
	template<typename T,typename ...Args>inline void put(const char ch,T a,Args ...args){
	    put(ch,a);
		put(ch,args...);
	}
	inline void put(string s){
		for(int i=0,sz=s.length();i<sz;i++) putchar(s[i]);
	}
	inline void put(const char* s){
		for(int i=0,sz=strlen(s);i<sz;i++) putchar(s[i]);
	}
}
using namespace IO;
#define N 500005
#define mod 998244353
#define ll long long
int n,m,head[N],cnt,idx,rt[N],dep[N];
struct edge{
	int v,nxt;
}e[N<<1];
struct node{
	int ls,rs,sum,tag;
}t[N*25];
#define lc(x) t[x].ls
#define rc(x) t[x].rs
inline void add(int u,int v){
	e[++cnt]=(edge){v,head[u]},head[u]=cnt;
}
vector<int> g[N];
inline void push_up(int x){
	t[x].sum=(t[lc(x)].sum+t[rc(x)].sum)%mod;
}
inline void push_tag(int x,int val){
	t[x].sum=(ll)t[x].sum*val%mod;
	t[x].tag=(ll)t[x].tag*val%mod;
}
inline void push_down(int x){
	if(t[x].tag==1) return;
	if(lc(x)) push_tag(lc(x),t[x].tag);
	if(rc(x)) push_tag(rc(x),t[x].tag);
	t[x].tag=1;
}
inline void update(int &x,int l,int r,int pos){
	if(!x) x=++idx,t[x].sum=t[x].tag=1;
	if(l==r) return;
	int mid=l+r>>1;
	if(pos<=mid) update(lc(x),l,mid,pos);
	else update(rc(x),mid+1,r,pos);
	push_up(x);
}
inline int query(int x,int l,int r,int pre){
	if(!x) return 0;
	if(r<=pre) return t[x].sum;
	int mid=l+r>>1;
	push_down(x);
	return (query(lc(x),l,mid,pre)+(pre>mid?query(rc(x),mid+1,r,pre):0))%mod;
}
inline int merge(int x,int y,int l,int r,int &s1,int &s2){
	if(!x&&!y) return 0;
	push_down(x),push_down(y);
	if(!x){
		s1=(s1+t[y].sum)%mod;
		push_tag(y,s2);
		return y;
	}
	if(!y){
		s2=(s2+t[x].sum)%mod;
		push_tag(x,s1);
		return x;
	}
	if(l==r){
		int prex=t[x].sum,prey=t[y].sum;
		s1=(s1+prey)%mod;
		t[x].sum=((ll)t[x].sum*s1%mod+(ll)t[y].sum*s2%mod)%mod;
		s2=(s2+prex)%mod;
		return x;
	}
	int mid=l+r>>1;
	lc(x)=merge(lc(x),lc(y),l,mid,s1,s2);
	rc(x)=merge(rc(x),rc(y),mid+1,r,s1,s2);
	push_up(x);
	return x;
}
inline void dfs(int x,int fa){
	dep[x]=dep[fa]+1;
	int maxn=0;
	for(auto v:g[x]) maxn=max(maxn,dep[v]);
	update(rt[x],0,n,maxn);
	for(int i=head[x];i;i=e[i].nxt){
		int v=e[i].v;
		if(v==fa) continue;
		dfs(v,x);
		int s1=query(rt[v],0,n,dep[x]),s2=0;
		rt[x]=merge(rt[x],rt[v],0,n,s1,s2);
	}
}
int main(){
	read(n);
	for(int i=1,u,v;i<n;i++)
		read(u,v),add(u,v),add(v,u);
	read(m);
	for(int i=1,u,v;i<=m;i++) 
		read(u,v),g[v].push_back(u);
	dfs(1,0);
	put('\n',query(rt[1],0,n,0));
	return 0;
}
posted @ 2022-10-06 15:16  fzj2007  阅读(21)  评论(0编辑  收藏  举报