[NOI2020]命运 题解
题意分析
给出 $m$ 条直链,要求每条链上至少选中一条边,求方案数。
思路分析
想到用 DP 求解,自下向上统计。可以发现,对于下端点相同的链,若上端点最深的链满足条件,则这些链都会满足条件。另外,若一些对于经过一个节点的链,若比这个节点更深的边全部不选,则可以等同于这些链的下端点相同。
因此,可以设 $f_{x,i}$ 表示经过 $x$ 的未满足条件的上端点的最大深度为 $i$ 时的方案数,若没有经过 $x$ 的链则为 $0$ 。题目所求即为所有条件都满足时的方案数,即 $f_{root,0}$ 。对于 $x$ 的子节点 $y$ ,若选中边 $(x,y)$ ,则经过所有经过 $y$ 的边都可以满足条件;若不选,则只有上端点深度 $\leq i$ 的链才可以满足条件。即:
$$f'_{x,i}=f_{x,i} *\sum_{j=0}^{d_x} f_{y,j}+\sum_{j=0}^i\sum_{k=0,max(j,k)=i}^i f_{x,j}*f_{y,k}$$
整理得:
$$f'_{x,i}=f_{x,i} *(\sum_{j=0}^{d_x} f_{y,j}+\sum_{j=0}^{i} f_{y,j})+f_{y,i}\sum_{j=0}^{i-1} f_{x,j}$$
发现这个式子显然可以用前缀和优化。想到用线段树合并来做,其中 $\sum_{j=0}^{d_x} f_{y,j}$ 对于同一个 $y$ 来说是定值,可以先查询;剩下的元素都可以在线段树合并的过程中计算。可以维护一个乘法标记,前缀和可以通过先走左子树再走右子树来计算。
#include<iostream>
#include<cstdio>
#include<vector>
#define ll long long
using namespace std;
const int N=5e5+100;
const ll P=998244353;
struct Seg
{
int lson,rson;
ll sum,mul;
#define lson(i) t[i].lson
#define rson(i) t[i].rson
#define sum(i) t[i].sum
#define mul(i) t[i].mul
}t[N<<5];
int n,m,tot,cnt;
int head[N],ver[2*N],Next[2*N];
int r[N],d[N];
vector<int> query[N];
void add(int x,int y)
{
ver[++tot]=y,Next[tot]=head[x],head[x]=tot;
ver[++tot]=x,Next[tot]=head[y],head[y]=tot;
}
void add_query(int x,int y)
{
query[y].push_back(x);
}
void spread(int p)
{
if(lson(p))
{
sum(lson(p))=sum(lson(p))*mul(p)%P;
mul(lson(p))=mul(lson(p))*mul(p)%P;
}
if(rson(p))
{
sum(rson(p))=sum(rson(p))*mul(p)%P;
mul(rson(p))=mul(rson(p))*mul(p)%P;
}
mul(p)=1;
}//下传标记
int change(int p,int l,int r,int k)
{
if(!p)
p=++cnt,sum(p)=mul(p)=1;//新建节点
if(l==r)
return p;
spread(p);
int mid=l+r>>1;
if(k<=mid)
lson(p)=change(lson(p),l,mid,k);
else
rson(p)=change(rson(p),mid+1,r,k);
return p;
}//新树,类似单点修改
ll ask(int p,int l,int r,int k)
{
if(!p)
return 0;
if(r<=k)
return sum(p);
spread(p);
int mid=l+r>>1;
ll now=0;
now=(now+ask(lson(p),l,mid,k))%P;
if(k>mid)
now=(now+ask(rson(p),mid+1,r,k))%P;
return now;
}//区间查询 [0,k]
int merge(int x,int y,int l,int r,ll &sumx,ll &sumy)
{
if(!x && !y)
return 0;
if(!x)
{
sumy=(sumy+sum(y))%P;
sum(y)=(sum(y)*sumx)%P;
mul(y)=(mul(y)*sumx)%P;
return y;
}
if(!y)
{
sumx=(sumx+sum(x))%P;
sum(x)=(sum(x)*sumy)%P;
mul(x)=(mul(x)*sumy)%P;
return x;
}
if(l==r)
{
int nowx=sum(x);//sum(x) 的值会改变,先存储
sumy=(sumy+sum(y))%P;
sum(x)=(sum(x)*sumy+sum(y)*sumx)%P;
sumx=(sumx+nowx)%P;//f[x][i]*f[y][i] 算过一遍就不重复计算
return x;
}
spread(x),spread(y);
int mid=l+r>>1;
lson(x)=merge(lson(x),lson(y),l,mid,sumx,sumy);
rson(x)=merge(rson(x),rson(y),mid+1,r,sumx,sumy);
sum(x)=(sum(lson(x))+sum(rson(x)))%P;
return x;
}
void solve(int x,int fa)
{
d[x]=d[fa]+1;
int maxd=0;
for(int i=0;i<query[x].size();i++)
maxd=max(maxd,d[query[x][i]]);//找到最大深度
r[x]=change(r[x],0,n,maxd);//新节点
for(int i=head[x];i;i=Next[i])
{
int y=ver[i];
if(y==fa)
continue;
solve(y,x);
ll nowy=ask(r[y],0,n,d[x]),nowx=0;//先查询定值
r[x]=merge(r[x],r[y],0,n,nowx,nowy);
}
}
int main()
{
scanf("%d",&n);
for(int i=1,x,y;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y);
}
scanf("%d",&m);
for(int i=1,x,y;i<=m;i++)
{
scanf("%d%d",&x,&y);
add_query(x,y);
}
solve(1,0);
printf("%lld",ask(r[1],0,n,0));
return 0;
}