P6773 [NOI2020]命运
整体DP
很明显计算答案需要用容斥计算,如果暴力容斥的话,就是枚举哪些路径不符合条件,在这些路径的并集中的边都不能取,其他边任意取,设当前取了$i$条路径,那么对答案的贡献是$(-1)^i2^{n-1-Union}$
但是可以发现这个路径是自下往上的,可以考虑树形DP,设$dp[i][j][k]$表示在i的子树内,已选$k$条路径的上面那个端点的深度的最小值为$j$的方案数,特别的如果在这个子树内不选任何一条路径的话,$j$为$maxde$
但其实在容斥的时候,我们并不关心这个$k$的具体取值,只要关心其奇偶性即可,那么可以去掉这一维,在设初始值的时候将$-1$乘到值里面,在之后乘起来的时候,$-1$会帮助数值自动变号
考虑一个儿子一个儿子更新$dp[x][i]$,分情况讨论
如果$i\leqslant de[x]$
$dp[x][i]=\sum_{j=i+1}^{maxde}2dp[u][j]dp[x][i]+\sum_{j=i+1}^{de[x]}dp[u][j]dp[x][i]+\sum_{j=i}^{maxde}dp[u][i]dp[x][j]$
第一部分表示$u$这棵子树内的最浅的那个不能选的路径是小于当前$x$的深度,那么$u->x$这条边是可以任意选择;第二部分就表示$u$中最浅的点比$i$深,那么当前最浅的那个节点依然是$i$;第三部分表示$u$中最浅的点比$i$浅,那么当前最浅的点需要更新
如果$i>de[x]$
$dp[x][i]=\sum_{j=i+1}^{maxde}2dp[u][j]dp[x][i]+\sum_{j=i}^{maxde}2dp[u][i]dp[x][j]$
这也是类似的
可以发现第二维不为$0$的取值是较少的,只有在较前的点才会变多,那么就用线段树合并维护这个DP(跟[PKUWC2018]Minimax类似)
但是这个DP的方程需要分类讨论,就很难直观的进行维护修改,考虑如何用一个式子来表示这个DP方程
#include <bits/stdc++.h> #define mod 998244353 using namespace std; const int N=5*1e5+100; int n,m,w,de[N],maxde,last[N],root[N],cnt; int tot,first[N],nxt[N*2],point[N*2]; struct node { int ls,rs; long long dp,tag; }sh[N*40]; inline void add(long long &a,long long b){a=(a+b);((a>mod)?a-=mod:a=a);} inline void del(long long &a,long long b){a=(a-b+mod)%mod;} inline void mul(long long &a,long long b){a=(a*b)%mod;} inline bool cmp(int a,int b){return(de[a]<de[b]);} inline int read() { int f=1,x=0;char s=getchar(); while(s<'0'||s>'9'){if(s=='-')f=-1;s=getchar();} while(s>='0'&&s<='9'){x=x*10+s-'0';s=getchar();} return x*f; } inline void add_edge(int x,int y) { tot++; nxt[tot]=first[x]; first[x]=tot; point[tot]=y; } void dfs(int x,int fa) { for (int i=first[x];i!=-1;i=nxt[i]) { int u=point[i]; if (u==fa) continue; de[u]=de[x]+1; dfs(u,x); } } inline void pushup(int x) { sh[x].dp=(sh[sh[x].ls].dp+sh[sh[x].rs].dp)%mod; } inline void pushdown(int x) { if (sh[x].tag==1) return; if (sh[x].ls) mul(sh[sh[x].ls].dp,sh[x].tag),mul(sh[sh[x].ls].tag,sh[x].tag); if (sh[x].rs) mul(sh[sh[x].rs].dp,sh[x].tag),mul(sh[sh[x].rs].tag,sh[x].tag); sh[x].tag=1; } int insert(int x,int l,int r,int wh,int v) { if (!x) x=++cnt; sh[x].tag=1; if (l==r) { sh[x].dp=v; return x; } int mid=(l+r)>>1; if (wh<=mid) sh[x].ls=insert(sh[x].ls,l,mid,wh,v); else sh[x].rs=insert(sh[x].rs,mid+1,r,wh,v); pushup(x); return x; } void change(int x,int l,int r,int ll,int rr) { if (ll<=l && rr>=r) { mul(sh[x].dp,2);mul(sh[x].tag,2); return; } int mid=(l+r)>>1; pushdown(x); if (ll<=mid) change(sh[x].ls,l,mid,ll,rr); if (rr>mid) change(sh[x].rs,mid+1,r,ll,rr); pushup(x); } int merge(int a,int b,int l,int r,long long sa,long long sb,long long pa,long long pb) { if (!a) { mul(sh[b].dp,(sa-pa+mod)%mod); mul(sh[b].tag,(sa-pa+mod)%mod); return b; } if (!b) { mul(sh[a].dp,(sb-pb+mod)%mod); mul(sh[a].tag,(sb-pb+mod)%mod); return a; } if (l==r) { add(pa,sh[a].dp);add(pb,sh[b].dp); sh[a].dp=(sh[a].dp*(sb-pb+mod)%mod+sh[b].dp*(sa-pa+mod)%mod+sh[a].dp*sh[b].dp%mod)%mod; return a; } int mid=(l+r)>>1;long long ta=pa,tb=pb; pushdown(a);pushdown(b); add(ta,sh[sh[a].ls].dp);add(tb,sh[sh[b].ls].dp); sh[a].ls=merge(sh[a].ls,sh[b].ls,l,mid,sa,sb,pa,pb); sh[a].rs=merge(sh[a].rs,sh[b].rs,mid+1,r,sa,sb,ta,tb); pushup(a); return a; } void dfs1(int x,int fa) { if (last[x]) root[x]=insert(root[x],1,maxde,last[x],mod-1); root[x]=insert(root[x],1,maxde,maxde,1); for (int i=first[x];i!=-1;i=nxt[i]) { int u=point[i]; if (u==fa) continue; dfs1(u,x); change(root[u],1,maxde,de[x]+1,maxde); root[x]=merge(root[x],root[u],1,maxde,sh[root[x]].dp,sh[root[u]].dp,0,0); } } int main() { tot=-1; memset(first,-1,sizeof(first)); memset(nxt,-1,sizeof(nxt)); n=read(); for (int i=1;i<n;i++) { int u=read(),v=read(); add_edge(u,v);add_edge(v,u); } de[1]=1; dfs(1,1); for (int i=1;i<=n;i++) maxde=max(maxde,de[i]); maxde++; m=read(); for (int i=1;i<=m;i++) { int u=read(),v=read(); last[v]=max(last[v],de[u]); } dfs1(1,1); printf("%lld\n",sh[root[1]].dp); }