[NOI2020] 命运

链接

题目大意

给定一颗树,需要给每条边赋边权0/1,并且对于给定的 \(k\) 条链,保证链两端为祖先-儿子关系,要求每条链上至少有一个1。求方案数。

题解

考虑dp,用 \(f_{i,j}\) 表示下端点在 \(i\) 的子树中且未包含1的所有链里上端点最深为 \(j\)。即除此之外子树中的部分都处理好了。

显然有树形dp转移方程:

\[f'_{u,i}=\sum_{j=0}^{dep_i}f_{u,i}f_{v,j}+\sum_{j=0}^{i}f_{u,i}f_{v,j}+\sum_{j=0}^{i-1}f_{v,i}f_{u,j} \]

前缀和一下就可以做到 \(O(n^2)\)然而我线上赛时居然以为这是 \(O(n^3)\)

然后考虑转换成前缀和的形式,即 \(g_{i,j}=\sum_{k=0}^{j} f_{i,k}\)

那么:

\[f'_{u,i}=f_{u,i}\times (g_{v,dep_i}+ g_{v,i})+g_{u,i-1}\times f_{v,i} \]

考虑使用线段树合并,即在合并时先处理左半部分,处理完后顺便处理前缀和,然后用前缀和更新右半部分。

时间复杂度 \(O(n\log n)\)

#include<iostream>
#include<cstdio>
#include<cstring>
#define N 500010
#define mod 998244353
using namespace std;
int nxt[N<<1],to[N<<1],head[N],cnt;
void add(int u,int v)
{
    nxt[++cnt]=head[u];
    to[cnt]=v;
    head[u]=cnt;
}
int dep[N],fa[N];
int val[N*40],tag[N*40],ls[N*40],rs[N*40],tot;
void set_tag(int u,int v)
{
    if(!u) return;
    val[u]=1ll*val[u]*v%mod;tag[u]=1ll*tag[u]*v%mod;
}
void push_down(int u)
{
    if(tag[u]==1) return;
    set_tag(ls[u],tag[u]),set_tag(rs[u],tag[u]);tag[u]=1;
}
void insert(int &u,int l,int r,int p)
{
    u=++tot,tag[u]=val[u]=1;
    if(l==r) return;
    int mid=(l+r)>>1;
    if(p<=mid) insert(ls[u],l,mid,p);
    else insert(rs[u],mid+1,r,p);
}
int answer(int u,int l,int r,int L,int R)
{
    if(!u) return 0;
    if(L<=l && r<=R) return val[u];
    int mid=(l+r)>>1,res=0;
    push_down(u);
    if(L<=mid) res+=answer(ls[u],l,mid,L,R);
    if(R>mid) res+=answer(rs[u],mid+1,r,L,R);
    return res%mod;
}
int merge(int x,int y,int l,int r,int &s1,int &s2)
{
    if(!x || !y)
    {
		if(y)s1=(s1+val[y])%mod,set_tag(y,s2);
        if(x)s2=(s2+val[x])%mod,set_tag(x,s1);
		return x+y;
    }
    if(l==r)
    {
        int tx=val[x],ty=val[y];
        val[x]=(1ll*val[x]*(s1+val[y])+1ll*val[y]*s2)%mod;
        s1=(s1+ty)%mod;
        s2=(s2+tx)%mod;
        return x;
    }
    push_down(x),push_down(y);
    int mid=(l+r)>>1;
    ls[x]=merge(ls[x],ls[y],l,mid,s1,s2);
    rs[x]=merge(rs[x],rs[y],mid+1,r,s1,s2);
    val[x]=(val[ls[x]]+val[rs[x]])%mod;
    return x;
}
int up[N],root[N],n;
void dfs(int u,int p)
{
    dep[u]=dep[p]+1;fa[u]=p;
    for(int i=head[u];i;i=nxt[i])
    if(to[i]!=p) dfs(to[i],u);
}
void solve(int u)
{
    insert(root[u],0,n,up[u]);
    for(int i=head[u];i;i=nxt[i])
    {
        int v=to[i];
        if(v==fa[u]) continue;
        solve(v);
        int p=answer(root[v],0,n,0,dep[u]),p2=0;
        root[u]=merge(root[u],root[v],0,n,p,p2);
    }
}
int main()
{
    int m;
    scanf("%d",&n);
    for(int i=1;i<n;i++)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        add(x,y),add(y,x);
    }
    dfs(1,0);
    scanf("%d",&m);
    for(int i=1;i<=m;i++)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        up[y]=max(up[y],dep[x]);
    }
    solve(1);
    printf("%d\n",answer(root[1],0,n,0,0));
    return 0;
}
posted @ 2020-08-20 10:46  Flying2018  阅读(187)  评论(0编辑  收藏  举报