[NOI2020]命运 题解
考虑\(dp.\)
记 \(dp_{x,i}\) 表示：\(x\) 子树内的所有边以及 \(x\) 到其父亲的边的状态（是否选取）都已经确定，且目前尚未被满足的限制中，上端点的最大深度 \(\leq i\) 的方案数（若所有限制都已被满足，则记最大深度为 \(0\)）。
不难发现，转移时相当于把每个儿子的 \(dp\) 数组对位相乘；最后再处理 \(x\) 到父亲的边：选取这条边会使子树内的限制全部被满足，因此要把当前 \(dp\) 数组的总和（\(dp\) 是前缀和的形式，故总和即其最后一项）加到数组的每一个位置上。
可以用带\(tag\)的线段树合并维护\(.\)
我的考场代码写的是启发式合并\(.\)
启发式合并代码\(:\)
#include <bits/stdc++.h>
#define LL long long
using namespace std;
// Reads the next non-negative decimal integer from stdin.
// Every non-digit character in front of the number is skipped (so a leading
// '-' is ignored). Returns 0 when the input ends before any digit is found.
inline int read(){
    // getchar() returns int: keeping the value as int lets EOF be told apart
    // from real bytes and keeps the argument of isdigit() inside its valid
    // domain (unsigned char values or EOF).
    int c = getchar();
    // Stop at EOF as well, otherwise truncated input would loop forever.
    while (c != EOF && !isdigit(c)) c = getchar();
    int x = 0;
    while (c != EOF && isdigit(c)){
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x;
}
// N: size of the per-vertex arrays; P: modulus used for all arithmetic.
const int N = 500050, P = 998244353;
// Adjacency lists kept as linked edge records:
//   He[v] - index of the most recently added edge leaving v (0 = none)
//   Ne[e] - next edge in the same list (0 terminates the list)
//   To[e] - endpoint the edge e leads to
//   _     - number of edge records created so far
int To[N << 1], Ne[N << 1], He[N], _;
// Adds the undirected edge {x, y} as two directed records, x->y first.
inline void adde(int x,int y){
    const int fwd = ++_;
    To[fwd] = y;
    Ne[fwd] = He[x];
    He[x] = fwd;

    const int bwd = ++_;
    To[bwd] = x;
    Ne[bwd] = He[y];
    He[y] = bwd;
}
int n,m,fa[N],dpt[N],lim[N];
inline void dfs(int x){
dpt[x] = dpt[fa[x]] + 1;
for (int y,p = He[x]; p ; p = Ne[p]) if ((y=To[p])^fa[x]) fa[y] = x,dfs(y);
}
// Capacity of the segment-tree node arrays.
const int V = N * 60;
// In-place modular addition: x becomes x + y, reduced by a single
// subtraction of P (sufficient when both operands are already below P).
inline void upd(int &x,int y){
    x += y;
    if (x >= P) x -= P;
}
// Per-node storage of the dynamic segment trees (index 0 acts as "no node"):
//   lc / rc   - left and right child indices
//   val       - node value; up() keeps an inner node equal to its last child
//   mul / add - pending lazy tags pushed to the children by down()
//   siz       - number of stored leaves below the node
//   cnto      - highest node index handed out so far
int lc[V],rc[V],val[V],mul[V],add[V],siz[V],cnto;
// Stack of recycled node indices; `top` is the number of entries in it.
int stk[V],top;
// Hands out a node in its empty state, preferring a recycled index over a
// fresh one.
inline int New(){
    int o;
    if (top > 0) o = stk[top--];
    else o = ++cnto;
    lc[o] = 0;
    rc[o] = 0;
    val[o] = 0;
    siz[o] = 0;
    add[o] = 0;
    mul[o] = 1;   // neutral multiplicative tag
    return o;
}
// Pushes node x onto the free stack so that New() can hand it out again.
inline void DD(int x){
    ++top;
    stk[top] = x;
}
inline void tmul(int o,int v){ if (o) mul[o] = (LL)mul[o] * v % P,add[o] = (LL)add[o] * v % P,val[o] = (LL)val[o] * v % P; }
// Applies "add v" to node o: both its value and its pending add tag grow by
// v modulo P. Does nothing for the null node 0.
inline void tadd(int o,int v){
    if (o == 0) return;
    upd(add[o], v);
    upd(val[o], v);
}
// Pushes the pending tags of node o to both children and clears them.
// The multiplicative tag goes first, then the additive one.
inline void down(int o){
    if (mul[o] != 1){
        tmul(lc[o], mul[o]);
        tmul(rc[o], mul[o]);
        mul[o] = 1;
    }
    if (add[o] != 0){
        tadd(lc[o], add[o]);
        tadd(rc[o], add[o]);
        add[o] = 0;
    }
}
// Recomputes node o from its children: val comes from the right child when
// it exists (otherwise from the left one), siz is the sum of both sizes.
inline void up(int o){
    const int last = rc[o] ? rc[o] : lc[o];
    val[o] = val[last];
    siz[o] = siz[lc[o]] + siz[rc[o]];
}
int pp,vv;
inline void Del(int &o,int l,int r){
if (!o || r < pp) return; if (l >= pp){ DD(o),o = 0; return; }
down(o); int mid = l+r>>1; Del(lc[o],l,mid); Del(rc[o],mid+1,r);
up(o); if (siz[o] == 0) DD(o),o = 0;
}
// Stores value vv at position pp inside the subtree o covering [l, r],
// creating the nodes on the path as needed; an existing leaf is overwritten.
inline void Ins(int &o,int l,int r){
    if (o == 0) o = New();
    if (l == r){
        val[o] = vv;
        siz[o] = 1;
        return;
    }
    down(o);
    const int mid = (l + r) >> 1;
    if (pp > mid) Ins(rc[o], mid + 1, r);
    else Ins(lc[o], l, mid);
    up(o);
}
// Bounds of the position range targeted by the next Mul() / Add() call.
int ll,rr;
// Multiplies by vv every stored leaf of subtree o (covering [l, r]) whose
// position lies inside [ll, rr].
inline void Mul(int o,int l,int r){
    if (o == 0) return;
    const bool covered = (l >= ll) && (r <= rr);
    if (covered){
        tmul(o, vv);
        return;
    }
    down(o);
    const int mid = (l + r) >> 1;
    if (mid >= ll) Mul(lc[o], l, mid);
    if (mid < rr) Mul(rc[o], mid + 1, r);
    up(o);
}
// Adds vv to every stored leaf of subtree o (covering [l, r]) whose position
// lies inside [ll, rr].
inline void Add(int o,int l,int r){
    if (o == 0) return;
    const bool covered = (l >= ll) && (r <= rr);
    if (covered){
        tadd(o, vv);
        return;
    }
    down(o);
    const int mid = (l + r) >> 1;
    if (mid >= ll) Add(lc[o], l, mid);
    if (mid < rr) Add(rc[o], mid + 1, r);
    up(o);
}
// Tells whether a leaf for position pp exists in the subtree o covering
// [l, r]. Tags are pushed down along the path that is walked.
inline bool Is(int o,int l,int r){
    while (o != 0 && l != r){
        down(o);
        const int mid = (l + r) >> 1;
        if (pp <= mid){
            o = lc[o];
            r = mid;
        } else {
            o = rc[o];
            l = mid + 1;
        }
    }
    return o != 0;
}
// Result slot of Query() and the right end of the segment that produced it;
// the caller resets qi to -1 before starting a query.
int qans,qi;
// Looks for the last stored position strictly below pp in subtree o
// (covering [l, r]) and leaves that node value in qans. The right child is
// explored first; once a segment has answered, everything ending at or
// before qi is skipped.
inline void Query(int o,int l,int r){
    if (o == 0) return;
    if (r <= qi) return;    // an answer further right was already taken
    if (l >= pp) return;    // segment lies entirely at or after pp
    if (r < pp){
        qi = r;
        qans = val[o];
        return;
    }
    down(o);
    const int mid = (l + r) >> 1;
    Query(rc[o], mid + 1, r);
    Query(lc[o], l, mid);
}
// Adds v to every stored position in [l, r] of the tree rooted at rt.
inline void radd(int rt,int l,int r,int v){
    ll = l;
    rr = r;
    vv = v;
    Add(rt, 0, n);
}
// Multiplies by v every stored position in [l, r] of the tree rooted at rt.
inline void rmul(int rt,int l,int r,int v){
    ll = l;
    rr = r;
    vv = v;
    Mul(rt, 0, n);
}
// Stores value v at position p in the tree rooted at rt (rt is created when
// it is still 0).
inline void rins(int &rt,int p,int v){
    pp = p;
    vv = v;
    Ins(rt, 0, n);
}
// Drops every stored position >= p from the tree rooted at rt.
inline void rdel(int &rt,int p){
    pp = p;
    Del(rt, 0, n);
}
// Tells whether position p is stored in the tree rooted at rt.
inline bool ris(int rt,int p){
    pp = p;
    const bool present = Is(rt, 0, n);
    return present;
}
// Flattened leaves of a tree: ti[k] is the position and tv[k] the value of
// the k-th leaf in increasing position order (1-based), cntt their count.
int ti[N],tv[N],cntt;
// Appends every leaf of subtree o (covering [l, r]) to ti[] / tv[] from left
// to right and recycles all visited nodes, children before their parent.
inline void Dfs(int o,int l,int r){
    if (o == 0) return;
    if (l != r){
        down(o);
        const int mid = (l + r) >> 1;
        Dfs(lc[o], l, mid);
        Dfs(rc[o], mid + 1, r);
    } else {
        ++cntt;
        ti[cntt] = l;
        tv[cntt] = val[o];
    }
    DD(o);
}
// Final answer printed by main.
int ans;
// Walks from o (covering [l, r]) to the left-most stored leaf, preferring the
// left child whenever it exists, and copies that leaf's value into ans.
// ans is left untouched if the walk runs into the null node.
inline void Ask(int o,int l,int r){
    while (o != 0){
        if (l == r){
            ans = val[o];
            return;
        }
        down(o);
        const int mid = (l + r) >> 1;
        if (lc[o] != 0){
            o = lc[o];
            r = mid;
        } else {
            o = rc[o];
            l = mid + 1;
        }
    }
}
// Same traversal as Dfs(): appends every leaf of subtree o (covering [l, r])
// to ti[] / tv[] from left to right, but keeps the nodes instead of
// recycling them.
inline void Dfs2(int o,int l,int r){
    if (o == 0) return;
    if (l != r){
        down(o);
        const int mid = (l + r) >> 1;
        Dfs2(lc[o], l, mid);
        Dfs2(rc[o], mid + 1, r);
        return;
    }
    ++cntt;
    ti[cntt] = l;
    tv[cntt] = val[o];
}
// Multiplies two trees position-wise and returns the root of the result.
// Each tree is read as a step function: a stored position holds the value
// that applies from there up to the next stored position. The tree with
// fewer leaves is flattened (and its nodes recycled) and folded into the
// other one.
inline int Merge(int rt1,int rt2){
    if (siz[rt2] > siz[rt1]) std::swap(rt1, rt2);

    // Flatten the smaller tree; a sentinel position n+1 closes the last step.
    cntt = 0;
    Dfs(rt2, 0, n);
    ++cntt;
    ti[cntt] = n + 1;

    // Process the steps from right to left.
    for (int i = cntt - 1; i > 0; --i){
        const int pos = ti[i];
        if (!ris(rt1, pos)){
            // rt1 has no breakpoint here yet: create one carrying the value
            // of the closest stored position to its left.
            pp = pos;
            qi = -1;
            Query(rt1, 0, n);
            rins(rt1, pos, qans);
        }
        rmul(rt1, pos, ti[i + 1] - 1, tv[i]);
    }
    return rt1;
}
int T[N];
inline void dp(int x){
if (lim[x]) rins(T[x],0,0); rins(T[x],lim[x],1);
for (int y,p = He[x]; p ; p = Ne[p]) if ((y=To[p])^fa[x]) dp(y),T[x] = Merge(T[x],T[y]);
if (x == 1){ Ask(T[1],0,n); return; }
radd(T[x],0,n,val[T[x]]);
rdel(T[x],dpt[fa[x]]);
//cerr << "print " << x <<'\n',Print(T[x]);
}
int main(){
// freopen("destiny.in","r",stdin);
// freopen("destiny.out","w",stdout);
int i,x,y;
n = read();
for (i = 1; i < n; ++i) x = read(),y = read(),adde(x,y);
dfs(1);
for (i = 1; i <= n; ++i) lim[i] = 0;
m = read();
while (m--) x = read(),y = read(),lim[y] = max(lim[y],dpt[x]);
dp(1);
cout << ans << '\n';
}