CF1010F Tree

CF1010F Tree

重链剖分+\(NTT\)

直接考虑\(A_i\)并不容易,我们可以进行差分,具体来说,令:

\[B_i=A_i-\sum_{j \in son_i} A_j \]

这样我们只需要保证\(B_i>0\)即可,同时\(\sum B_i = x\)

根据插板法,如果有\(k\)个点,那么它的贡献就是\({x+k-1 \choose k-1}\)

接下来的任务就是对于每个\(k\),计算出方案数。

对于每个节点,建立答案的生成函数。

如果有两个子节点:

\[F_u(x)=x F_{v1} (x) F_{v2} (x)+1 \]

一个子节点:

\[F_u(x)=xF_v(x)+1 \]

无子节点:

\[F_u(x)=x+1 \]

利用重链剖分进行优化,进行链分治,首先计算出所有轻儿子的生成函数。

\(F_0=x\),从下至上的轻儿子生成函数乘上\(x\)\(F_1,F_2,\cdots,F_m\)(无轻儿子则\(F_i=x\))。

那么我们可以计算出重链顶端的生成函数:

\[F=(((F_0+1)F_1+1)F_2+1)\cdots +1\\ =F_0F_1F_2\cdots+F_1F_2\cdots+\cdots+1 \]

\(S=F_0F_1F_2\cdots+F_1F_2\cdots+\cdots+1,T=F_0F_1F_2\cdots\)

进行分治计算,计算出左右两部分的答案\(S_0,T_0,S_1,T_1\)

则:

\[S=(S_0-1)T_1+S_1\\ T=T_0T_1 \]

考虑一下复杂度上限,对于一颗大小为\(t\)的子树,我们会进行分治,分治合并时需要利用卷积,如果两个长度为\(n,m\)的多项式卷积复杂度近似看做\(O((n+m)\log x)\)(带了\(\log\)之后\(x\)是多少并不重要),那么我们单独考虑贡献,把分治当成一颗二叉树,对于每个叶子节点,若其大小为\(c\),那么它一共在\(\log t\)个节点有贡献,每次贡献看成\(c \log x\),总共的贡献就是\(c \log^2 x\),所以一颗大小为\(t\)的子树贡献就是\(O(t \log^2 t)\)

根据重链剖分轻子树大小总和为\(n \log n\),得出时间复杂度为\(O(n \log^3 n)\)

\(Code:\)

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<vector>
#define N 100005
#define ll long long
#define IT vector<int> :: iterator
using namespace std;
const int p=998244353;
int s,l,G[2][25],rev[N << 1];
void Add(int &x,int y)
{
    x=(x+y)%p;
}
void Del(int &x,int y)
{
    x=(x-y)%p;
}
void Mul(int &x,int y)
{
    x=(ll)x*y%p;
}
int add(int x,int y)
{
    return (x+y)%p;
}
int del(int x,int y)
{
    return (x-y)%p;
}
int mul(int x,int y)
{
    return (ll)x*y%p;
}
int ksm(int x,int y)
{
    int ans=1;
    while (y)
    {
        if (y & 1)
            Mul(ans,x);
        Mul(x,x);
        y >>=1;
    }
    return ans;
}
void Pre()
{
    G[0][23]=ksm(3,(p-1)/(1 << 23));
    G[1][23]=ksm(G[0][23],p-2);
    for (int i=22;i;--i)
    {
        G[0][i]=mul(G[0][i+1],G[0][i+1]);
        G[1][i]=mul(G[1][i+1],G[1][i+1]);
    }
}
void solve(int n)
{
    s=1,l=0;
    while (s<n)
        s <<=1,++l;
    for (int i=0;i<s;++i)
        rev[i]=(rev[i >> 1] >> 1) | ((i & 1) << l-1);
}
struct Poly
{
    int n;
    vector<int>a;
    int& operator [] (int x)
    {
        return a[x];
    }
    void read(int zn)
    {
        n=zn,a.clear();
        for (int i=0;i<n;++i)
            a.push_back(0),scanf("%d",&a[i]);
    }
    void print()
    {
        puts("--------------------");
        printf("Len: %d\n",n);
        for (int i=0;i<n;++i)
            printf("%d ",a[i]);
        putchar('\n');
        puts("--------------------");
    }
    void clean()
    {
        n=0,a.clear();
    }
    void reuse(int zn)
    {
        n=zn,a.clear();
        for (int i=0;i<n;++i)
            a.push_back(0);
    }
    void extend(int S=s)
    {
        int t=S-a.size();
        for (int i=1;i<=t;++i)
            a.push_back(0);
    }
    void rollback(int S)
    {
        int t=a.size()-S;
        for (int i=0;i<t;++i)
            a.pop_back();
    }
    void NTT(int t)
    {
        for (int i=0;i<s;++i)
            if (i<rev[i])
                swap(a[i],a[rev[i]]);
        for (int mid=1,o=1;mid<s;mid <<=1,++o)
            for (int j=0;j<s;j+=mid << 1)
            {
                int g=1;
                for (int k=0;k<mid;++k,Mul(g,G[t][o]))
                {
                    int x=a[j+k],y=mul(g,a[j+k+mid]);
                    a[j+k]=add(x,y);
                    a[j+k+mid]=del(x,y);
                }
            }
    }
    void minv(int S=s)
    {
        int t=ksm(S,p-2);
        for (int i=0;i<s;++i)
            Mul(a[i],t);
    }
};
Poly operator + (Poly f,Poly g)
{
    int n=max(f.n,g.n);
    f.n=n;
    f.extend(n),g.extend(n);
    for (int i=0;i<n;++i)
        Add(f[i],g[i]);
    return f;
}
void operator += (Poly &f,Poly &g)
{
    int n=max(f.n,g.n);
    f.n=n;
    f.extend(n),g.extend(n);
    for (int i=0;i<n;++i)
        Add(f[i],g[i]);
    g.rollback(g.n);
}
Poly operator * (Poly f,Poly g)
{
    int n=f.n,m=g.n;
    solve(n+m);
    f.extend(),g.extend();
    f.NTT(0),g.NTT(0);
    for (int i=0;i<s;++i)
        Mul(f[i],g[i]);
    f.NTT(1),f.minv();
    f.n=n+m-1,f.rollback(f.n);
    return f;
}
void operator *= (Poly &f,Poly &g)
{
    int n=f.n,m=g.n;
    solve(n+m);
    f.extend(),g.extend();
    f.NTT(0),g.NTT(0);
    for (int i=0;i<s;++i)
        Mul(f[i],g[i]);
    f.NTT(1),f.minv();
    f.n=n+m-1,f.rollback(f.n),g.rollback(m);
}
void polyswap(Poly &f,Poly &g)
{
    f.a.swap(g.a),swap(f.n,g.n);
}
int n,x,y,ans;
ll X;
struct edge
{
    int nxt,v;
    edge () {}
    edge (int Nxt,int V):nxt(Nxt),v(V) {}
}e[N << 1];
int tot,fr[N],sz[N],son[N],fa[N];
vector<int>H[N];
void link(int x,int y)
{
    ++tot;
    e[tot]=edge(fr[x],y),fr[x]=tot;
}
void dfs(int u)
{
    sz[u]=1;
    int mx=-1;
    for (int i=fr[u];i;i=e[i].nxt)
    {
        int v=e[i].v;
        if (v==fa[u])
            continue;
        fa[v]=u;
        dfs(v);
        sz[u]+=sz[v];
        if (sz[v]>mx)
            mx=sz[v],son[u]=v;
    }
}
void dfs2(int u,int tp)
{
    H[tp].push_back(u);
    if (!son[u])
        return;
    dfs2(son[u],tp);
    for (int i=fr[u];i;i=e[i].nxt)
    {
        int v=e[i].v;
        if (v==fa[u] || v==son[u])
            continue;
        dfs2(v,v);
    }
}
#define ls (p << 1)
#define rs (p << 1 | 1)
Poly Z,Z2,F[N],S[N << 2],T[N << 2];
void modify(int p,int l,int r,int x,Poly &a)
{
    if (l==r)
    {
        polyswap(S[p],a);
        a.clean();
        T[p]=S[p];
        ++S[p][0];
        return;
    }
    int mid=(l+r) >> 1;
    if (x<=mid)
        modify(ls,l,mid,x,a); else
        modify(rs,mid+1,r,x,a);
}
void calc(int p,int l,int r)
{
    if (l==r)
        return;
    int mid=(l+r) >> 1;
    calc(ls,l,mid);
    calc(rs,mid+1,r);
    --S[ls][0];
    S[p]=S[ls]*T[rs]+S[rs];
    T[p]=T[ls]*T[rs];
    S[ls].clean(),T[ls].clean();
    S[rs].clean(),T[rs].clean();
}
void Solve(int u)
{
    if (!son[u])
    {
        F[u].reuse(2);
        F[u][0]=F[u][1]=1;
        return;
    }
    int cnt=0;
    for (IT it=H[u].begin();it!=H[u].end();++it)
    {
        int v=*it;
        ++cnt;
        for (int i=fr[v];i;i=e[i].nxt)
        {
            int v2=e[i].v;
            if (v2==fa[v] || v2==son[v])
                continue;
            Solve(v2);
        }
    }
    --cnt;
    Z2=Z;
    modify(1,0,cnt,0,Z2);
    reverse(H[u].begin(),H[u].end());
    int rct=0;
    for (IT it=H[u].begin()+1;it!=H[u].end();++it)
    {
        ++rct;
        int v=*it;
        bool flag=false;
        for (int i=fr[v];i;i=e[i].nxt)
        {
            int v2=e[i].v;
            if (v2==fa[v] || v2==son[v])
                continue;
            flag=true;
            reverse(F[v2].a.begin(),F[v2].a.end());
            F[v2].a.push_back(0),++F[v2].n;
            reverse(F[v2].a.begin(),F[v2].a.end());
            modify(1,0,cnt,rct,F[v2]);
        }
        if (!flag)
            Z2=Z,modify(1,0,cnt,rct,Z2);
    }
    calc(1,0,cnt);
    polyswap(F[u],S[1]);
    S[1].clean(),T[1].clean();
}
int main()
{
    Pre();
    scanf("%d%lld",&n,&X);
    int zx=X%p;
    for (int i=1;i<n;++i)
    {
        scanf("%d%d",&x,&y);
        link(x,y),link(y,x);
    }
    dfs(1);
    dfs2(1,1);
    Z.reuse(2),Z[1]=1;
    Solve(1);
    int z1=1,z2=1;
    for (int i=1;i<=n;++i)
    {
        Add(ans,mul(mul(z1,z2),F[1][i]));
        Mul(z1,add(zx,i));
        Mul(z2,ksm(i,p-2));
    }
    ans=(ans%p+p)%p;
    printf("%d\n",ans);
    return 0;
}
posted @ 2021-02-21 16:15  GK0328  阅读(51)  评论(0编辑  收藏  举报