【题解】 「NOI2020」命运 树形dp+线段树合并 LOJ3340

Legend

给定一棵 \(n\) 节点的树,你需要把边黑白染色。当然还有 \(m\) 个限制,限制是一条简单路径,且两端满足祖先-后代关系,表示这条路径上所有的边至少有一条是黑色的,问方案数对 \(998 244 353\) 取模的值。

\(1 \le n \le 500000\)\(0 \le m \le 500000\)

Editorial

两端满足祖先-后代关系这个条件很奇怪,不妨从此下手。

\(dp_{i,j}\) 表示从子树 \(i\) 里向上伸出来的链 **上端点深度最大是 \(j\) **的方案数量。

现在考虑到了节点 \(i\),看看加入一棵子树 \(k\) 发生了什么:

\(dp'_{i,j}=\sum\limits_{s=0}^{j} dp_{i,j}\times dp_{k,s} + \sum\limits_{s=0}^{j-1} dp_{i,s} \times dp_{k,j}+ \sum\limits_{s=0}^{dep_{i}}dp_{i,j} \times dp_{k,s}\)

其中前两个转移是 \((i,k)\) 这条边没选的,最后一个是这条边选了的。

这样就可以 \(O(n)\) 进行转移。可以写出一个 \(O(n \cdot \rm{maxdep})\) 的做法。

把方程写成前缀和形式:

\[\begin{aligned} dp'_{i,j}&=dp_{i,j}\times S_{k,j} + dp_{k,j} \times S_{i,j-1} + dp_{i,j} \times S_{k,dep_i} \\ &=dp_{i,j}\times(S_{k,j}+S_{k,dep_i})+dp_{k,j}\times S_{i,j-1} \end{aligned} \]

考虑怎么快速算这一坨东西。发现前缀和其实可以通过线段树合并的时候顺带求出来。

于是复杂度就变成了 \(O(n \log n)\)

Code

算是个套路,但是第一次见还是很有趣的。

#include <bits/stdc++.h>

#define LL long long
#define debug(...) fprintf(stderr ,__VA_ARGS__)
#define __FILE(x)\
	freopen(#x".in" ,"r" ,stdin);\
	freopen(#x".out" ,"w" ,stdout

const int MX = 5e5 + 233;
const LL MOD = 998244353;

int read(){
	char k = getchar(); int x = 0;
	while(k < '0' || k > '9') k = getchar();
	while(k >= '0' && k <= '9') x = x * 10 + k - '0' ,k = getchar();
	return x;
}

std::vector<int> limit[MX];

int head[MX] ,tot ,n;
struct edge{
	int node ,next;
}h[MX << 1];
void addedge(int u ,int v ,int flg = 1){
	h[++tot] = (edge){v ,head[u]} ,head[u] = tot;
	if(flg) addedge(v ,u ,0);
}

struct node{
	int l ,r;
	LL sum ,mul;
	node *lch ,*rch;
}*root[MX];

node *newnode(int l ,int r){
	node *x = new node;
	x->l = l ,x->r = r;
	x->sum = 0 ,x->mul = 1;
	x->lch = x->rch = nullptr;
	return x;
}

void domul(node *x ,LL v){
	x->sum = x->sum * v % MOD;
	x->mul = x->mul * v % MOD;
}

void pushdown(node *x){
	if(x->mul != 1){
		if(x->lch != nullptr) domul(x->lch ,x->mul);
		if(x->rch != nullptr) domul(x->rch ,x->mul);
		x->mul = 1;
	}
}

void pushup(node *x){
	x->sum = 0;
	if(x->lch != nullptr) x->sum = x->lch->sum;
	if(x->rch != nullptr) x->sum = (x->sum + x->rch->sum) % MOD;
}

void change(node *x ,int p ,int val){
	if(x->l == x->r) return x->sum = val ,void();
	int mid = (x->l + x->r) >> 1;
	pushdown(x);
	if(p <= mid){
		if(x->lch == nullptr) x->lch = newnode(x->l ,mid);
		change(x->lch ,p ,val);
	}else{
		if(x->rch == nullptr) x->rch = newnode(mid + 1 ,x->r);
		change(x->rch ,p ,val);
	}return pushup(x);
}

LL sum(node *x ,int l ,int r){
	if(x == nullptr) return 0;
	if(l <= x->l && x->r <= r) return x->sum;
	pushdown(x);
	int mid = (x->l + x->r) >> 1;
	LL s = 0;
	if(l <= mid) s = sum(x->lch ,l ,r);
	if(r > mid) s = (s + sum(x->rch ,l ,r)) % MOD;
	return s;	
}

node *merge(node *x ,node *y ,LL &s1 ,LL &s2){
	if(x == nullptr){
		if(y != nullptr){
			s1 = (s1 + y->sum) % MOD;
			domul(y ,s2); 
		}
		return y;
	}
	if(y == nullptr){
		s2 = (s2 + x->sum) % MOD;
		domul(x ,s1);
		return x;
	}
	if(x->l == x->r){
		LL tmps2 = s2;
		s1 = (s1 + y->sum) % MOD;
		s2 = (s2 + x->sum) % MOD;
		domul(x ,s1);
		x->sum = (x->sum + y->sum * tmps2) % MOD;
	}
	else{
		pushdown(x) ,pushdown(y);
		x->lch = merge(x->lch ,y->lch ,s1 ,s2);
		x->rch = merge(x->rch ,y->rch ,s1 ,s2);
		pushup(x);
	}
	return x;
}

int dep[MX];
void dfs(int x ,int f ,int depth){
	dep[x] = depth;
	for(int i = head[x] ,d ; i ; i = h[i].next){
		if((d = h[i].node) == f) continue;
		dfs(d ,x ,depth + 1);
	}

	root[x] = newnode(0 ,n);
	int mx = 0;
	for(auto i : limit[x]){
		if(dep[i] > dep[x]) continue;
		// change(root[x] ,dep[i] ,1);
		mx = std::max(mx ,dep[i]);
	}
	change(root[x] ,mx ,1);

	for(int i = head[x] ,d ; i ; i = h[i].next){
		if((d = h[i].node) == f) continue;
		LL tmp = sum(root[d] ,0 ,dep[x]) ,tmp2 = 0;
		merge(root[x] ,root[d] ,tmp ,tmp2);
	}
	// debug("sum dp[%d] = %lld\n" ,x ,sum(root[x] ,0 ,n));
}

int main(){
	n = read();
	for(int i = 1 ,u ,v ; i < n ; ++i){
		u = read() ,v = read();
		addedge(u ,v);
	}
	int m = read();
	for(int i = 1 ,u ,v ; i <= m ; ++i){
		u = read() ,v = read();
		limit[u].push_back(v);
		limit[v].push_back(u);
	}
	dfs(1 ,0 ,1);
	printf("%lld\n" ,sum(root[1] ,0 ,0));
	return 0;
}
posted @ 2020-09-26 15:24  Imakf  阅读(191)  评论(0编辑  收藏  举报