[NOI2020] 命运
链接
题目大意
给定一颗树,需要给每条边赋边权0/1,并且对于给定的 \(k\) 条链,保证链两端为祖先-儿子关系,要求每条链上至少有一个1。求方案数。
题解
考虑dp,用 \(f_{i,j}\) 表示下端点在 \(i\) 的子树中且未包含1的所有链里上端点最深为 \(j\)。即除此之外子树中的部分都处理好了。
显然有树形dp转移方程:
\[f'_{u,i}=\sum_{j=0}^{dep_i}f_{u,i}f_{v,j}+\sum_{j=0}^{i}f_{u,i}f_{v,j}+\sum_{j=0}^{i-1}f_{v,i}f_{u,j}
\]
前缀和一下就可以做到 \(O(n^2)\),然而我线上赛时居然以为这是 \(O(n^3)\) 的。
然后考虑转换成前缀和的形式,即 \(g_{i,j}=\sum_{k=0}^{j} f_{i,k}\)
那么:
\[f'_{u,i}=f_{u,i}\times (g_{v,dep_i}+ g_{v,i})+g_{u,i-1}\times f_{v,i}
\]
考虑使用线段树合并,即在合并时先处理左半部分,处理完后顺便处理前缀和,然后用前缀和更新右半部分。
时间复杂度 \(O(n\log n)\)。
#include<iostream>
#include<cstdio>
#include<cstring>
#define N 500010
#define mod 998244353
using namespace std;
int nxt[N<<1],to[N<<1],head[N],cnt;
void add(int u,int v)
{
nxt[++cnt]=head[u];
to[cnt]=v;
head[u]=cnt;
}
int dep[N],fa[N];
int val[N*40],tag[N*40],ls[N*40],rs[N*40],tot;
void set_tag(int u,int v)
{
if(!u) return;
val[u]=1ll*val[u]*v%mod;tag[u]=1ll*tag[u]*v%mod;
}
void push_down(int u)
{
if(tag[u]==1) return;
set_tag(ls[u],tag[u]),set_tag(rs[u],tag[u]);tag[u]=1;
}
void insert(int &u,int l,int r,int p)
{
u=++tot,tag[u]=val[u]=1;
if(l==r) return;
int mid=(l+r)>>1;
if(p<=mid) insert(ls[u],l,mid,p);
else insert(rs[u],mid+1,r,p);
}
int answer(int u,int l,int r,int L,int R)
{
if(!u) return 0;
if(L<=l && r<=R) return val[u];
int mid=(l+r)>>1,res=0;
push_down(u);
if(L<=mid) res+=answer(ls[u],l,mid,L,R);
if(R>mid) res+=answer(rs[u],mid+1,r,L,R);
return res%mod;
}
int merge(int x,int y,int l,int r,int &s1,int &s2)
{
if(!x || !y)
{
if(y)s1=(s1+val[y])%mod,set_tag(y,s2);
if(x)s2=(s2+val[x])%mod,set_tag(x,s1);
return x+y;
}
if(l==r)
{
int tx=val[x],ty=val[y];
val[x]=(1ll*val[x]*(s1+val[y])+1ll*val[y]*s2)%mod;
s1=(s1+ty)%mod;
s2=(s2+tx)%mod;
return x;
}
push_down(x),push_down(y);
int mid=(l+r)>>1;
ls[x]=merge(ls[x],ls[y],l,mid,s1,s2);
rs[x]=merge(rs[x],rs[y],mid+1,r,s1,s2);
val[x]=(val[ls[x]]+val[rs[x]])%mod;
return x;
}
int up[N],root[N],n;
void dfs(int u,int p)
{
dep[u]=dep[p]+1;fa[u]=p;
for(int i=head[u];i;i=nxt[i])
if(to[i]!=p) dfs(to[i],u);
}
void solve(int u)
{
insert(root[u],0,n,up[u]);
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(v==fa[u]) continue;
solve(v);
int p=answer(root[v],0,n,0,dep[u]),p2=0;
root[u]=merge(root[u],root[v],0,n,p,p2);
}
}
int main()
{
int m;
scanf("%d",&n);
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
add(x,y),add(y,x);
}
dfs(1,0);
scanf("%d",&m);
for(int i=1;i<=m;i++)
{
int x,y;
scanf("%d%d",&x,&y);
up[y]=max(up[y],dep[x]);
}
solve(1);
printf("%d\n",answer(root[1],0,n,0,0));
return 0;
}
本文来自博客园,作者:Flying2018,转载请注明原文链接:https://www.cnblogs.com/Flying2018/p/13532968.html