【题解】 「NOI2020」命运 树形dp+线段树合并 LOJ3340
Legend
给定一棵 \(n\) 节点的树,你需要把边黑白染色。当然还有 \(m\) 个限制,限制是一条简单路径,且两端满足祖先-后代关系,表示这条路径上所有的边至少有一条是黑色的,问方案数对 \(998 244 353\) 取模的值。
\(1 \le n \le 500000\),\(0 \le m \le 500000\)。
Editorial
两端满足祖先-后代关系这个条件很奇怪,不妨从此下手。
设 \(dp_{i,j}\) 表示从子树 \(i\) 里向上伸出来的链 **上端点深度最大是 \(j\) **的方案数量。
现在考虑到了节点 \(i\),看看加入一棵子树 \(k\) 发生了什么:
\(dp'_{i,j}=\sum\limits_{s=0}^{j} dp_{i,j}\times dp_{k,s} + \sum\limits_{s=0}^{j-1} dp_{i,s} \times dp_{k,j}+ \sum\limits_{s=0}^{dep_{i}}dp_{i,j} \times dp_{k,s}\)。
其中前两个转移是 \((i,k)\) 这条边没选的,最后一个是这条边选了的。
这样就可以 \(O(n)\) 进行转移。可以写出一个 \(O(n \cdot \rm{maxdep})\) 的做法。
把方程写成前缀和形式:
\[\begin{aligned}
dp'_{i,j}&=dp_{i,j}\times S_{k,j} + dp_{k,j} \times S_{i,j-1} + dp_{i,j} \times S_{k,dep_i} \\
&=dp_{i,j}\times(S_{k,j}+S_{k,dep_i})+dp_{k,j}\times S_{i,j-1}
\end{aligned}
\]
考虑怎么快速算这一坨东西。发现前缀和其实可以通过线段树合并的时候顺带求出来。
于是复杂度就变成了 \(O(n \log n)\)。
Code
算是个套路,但是第一次见还是很有趣的。
#include <bits/stdc++.h>
#define LL long long
#define debug(...) fprintf(stderr ,__VA_ARGS__)
#define __FILE(x)\
freopen(#x".in" ,"r" ,stdin);\
freopen(#x".out" ,"w" ,stdout
const int MX = 5e5 + 233;
const LL MOD = 998244353;
int read(){
char k = getchar(); int x = 0;
while(k < '0' || k > '9') k = getchar();
while(k >= '0' && k <= '9') x = x * 10 + k - '0' ,k = getchar();
return x;
}
std::vector<int> limit[MX];
int head[MX] ,tot ,n;
struct edge{
int node ,next;
}h[MX << 1];
void addedge(int u ,int v ,int flg = 1){
h[++tot] = (edge){v ,head[u]} ,head[u] = tot;
if(flg) addedge(v ,u ,0);
}
struct node{
int l ,r;
LL sum ,mul;
node *lch ,*rch;
}*root[MX];
node *newnode(int l ,int r){
node *x = new node;
x->l = l ,x->r = r;
x->sum = 0 ,x->mul = 1;
x->lch = x->rch = nullptr;
return x;
}
void domul(node *x ,LL v){
x->sum = x->sum * v % MOD;
x->mul = x->mul * v % MOD;
}
void pushdown(node *x){
if(x->mul != 1){
if(x->lch != nullptr) domul(x->lch ,x->mul);
if(x->rch != nullptr) domul(x->rch ,x->mul);
x->mul = 1;
}
}
void pushup(node *x){
x->sum = 0;
if(x->lch != nullptr) x->sum = x->lch->sum;
if(x->rch != nullptr) x->sum = (x->sum + x->rch->sum) % MOD;
}
void change(node *x ,int p ,int val){
if(x->l == x->r) return x->sum = val ,void();
int mid = (x->l + x->r) >> 1;
pushdown(x);
if(p <= mid){
if(x->lch == nullptr) x->lch = newnode(x->l ,mid);
change(x->lch ,p ,val);
}else{
if(x->rch == nullptr) x->rch = newnode(mid + 1 ,x->r);
change(x->rch ,p ,val);
}return pushup(x);
}
LL sum(node *x ,int l ,int r){
if(x == nullptr) return 0;
if(l <= x->l && x->r <= r) return x->sum;
pushdown(x);
int mid = (x->l + x->r) >> 1;
LL s = 0;
if(l <= mid) s = sum(x->lch ,l ,r);
if(r > mid) s = (s + sum(x->rch ,l ,r)) % MOD;
return s;
}
node *merge(node *x ,node *y ,LL &s1 ,LL &s2){
if(x == nullptr){
if(y != nullptr){
s1 = (s1 + y->sum) % MOD;
domul(y ,s2);
}
return y;
}
if(y == nullptr){
s2 = (s2 + x->sum) % MOD;
domul(x ,s1);
return x;
}
if(x->l == x->r){
LL tmps2 = s2;
s1 = (s1 + y->sum) % MOD;
s2 = (s2 + x->sum) % MOD;
domul(x ,s1);
x->sum = (x->sum + y->sum * tmps2) % MOD;
}
else{
pushdown(x) ,pushdown(y);
x->lch = merge(x->lch ,y->lch ,s1 ,s2);
x->rch = merge(x->rch ,y->rch ,s1 ,s2);
pushup(x);
}
return x;
}
int dep[MX];
void dfs(int x ,int f ,int depth){
dep[x] = depth;
for(int i = head[x] ,d ; i ; i = h[i].next){
if((d = h[i].node) == f) continue;
dfs(d ,x ,depth + 1);
}
root[x] = newnode(0 ,n);
int mx = 0;
for(auto i : limit[x]){
if(dep[i] > dep[x]) continue;
// change(root[x] ,dep[i] ,1);
mx = std::max(mx ,dep[i]);
}
change(root[x] ,mx ,1);
for(int i = head[x] ,d ; i ; i = h[i].next){
if((d = h[i].node) == f) continue;
LL tmp = sum(root[d] ,0 ,dep[x]) ,tmp2 = 0;
merge(root[x] ,root[d] ,tmp ,tmp2);
}
// debug("sum dp[%d] = %lld\n" ,x ,sum(root[x] ,0 ,n));
}
int main(){
n = read();
for(int i = 1 ,u ,v ; i < n ; ++i){
u = read() ,v = read();
addedge(u ,v);
}
int m = read();
for(int i = 1 ,u ,v ; i <= m ; ++i){
u = read() ,v = read();
limit[u].push_back(v);
limit[v].push_back(u);
}
dfs(1 ,0 ,1);
printf("%lld\n" ,sum(root[1] ,0 ,0));
return 0;
}