Loading

NOIP 模拟 $52\; \rm 路径$

题解 \(by\;zj\varphi\)

本质上可以理解为:对每个长度 \(x\),求出长度为 \(x\) 的路径有多少条,记为 \(ans_x\),最后答案即为 \(\sum_x ans_x\cdot x^k\)。

考虑点分治子树合并,设 \(ans_x\) 表示答案中长度为 \(x\) 的路径有多少条,那么:

\[ans_x=\sum_{i=0}^{x}depa_i\cdot depb_{x-i} \]

其中 \(depa_i\) 表示已合并的子树中到分治中心长度为 \(i\) 的点个数,\(depb\) 表示当前子树。

发现这是个卷积形式,且模数是 \(998244353\),直接 \(\rm NTT\) 卷一下,但是会被卡常,用 unsigned long long 优化一下 \(\rm NTT\) 即可。

复杂度 \(\mathcal O(n\log^2 n)\)

Code
#include<bits/stdc++.h>
#define Re register
#define ri Re signed
#define pd(i) ++i
#define bq(i) --i
namespace IO{
    // Input is pulled from stdin in chunks of up to 2 MiB; [p1, p2) is the
    // unread part of the current chunk.
    char buf[1<<21],*p1=buf,*p2=buf;
    // Next input byte as a value in [0, 255], or -1 once stdin is exhausted.
    // Fully parenthesized so it can be used inside any expression.
    #define gc() ((p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2))?(-1):(int)(unsigned char)*p1++)
    struct nanfeng_stream{
        /// Reads one decimal integer into `x`.
        ///
        /// Every non-digit byte before the number is skipped; if any of the
        /// skipped bytes is '-', the result is negated. If the input ends
        /// before a digit is found, `x` is left as 0 and the call returns
        /// (the previous version looped forever in that situation).
        template<typename T>inline nanfeng_stream &operator>>(T &x) {
            bool neg=false;
            x=0;
            // Keep the byte in an int so that -1 (end of input) stays
            // distinguishable from every real byte value.
            int ch=gc();
            while(ch!=-1&&(ch<'0'||ch>'9')) {
                neg|=(ch=='-');
                ch=gc();
            }
            while(ch>='0'&&ch<='9') {
                x=x*10+(ch-'0');
                ch=gc();
            }
            if (neg) x=-x;
            return *this;
        }
    }cin;
}
using IO::cin;
namespace nanfeng{
    #define FI FILE *IN
    #define FO FILE *OUT
    template<typename T>inline T cmax(T x,T y) {return x>y?x:y;}
    template<typename T>inline T cmin(T x,T y) {return x>y?y:x;}
    using ll=long long;
    using ull=unsigned long long;
    static const int N=1<<21,MOD=998244353;
    struct edge{int v,nxt;}e[N<<1];
    int first[N],siz[N],dep[N],R[N],n,k,t=1,cnt,pos,cmp,mx;
    bool G[N];
    ll w1[N],w2[N],ans[N],as;
    ull a[N],b[N],c[N],d[N];
    auto add=[](int u,int v) {
        e[t].v=v,e[t].nxt=first[u],first[u]=t++;
        e[t].v=u,e[t].nxt=first[v],first[v]=t++;
    };
    auto MD=[](ll x) {return x>=MOD?x-MOD:x;};
    auto fpow=[](ll x,int y) {
        ll res=1;
        while(y) {
            if (y&1) res=res*x%MOD;
            x=x*x%MOD;
            y>>=1;
        }
        return res;
    };
    void dfs_find(int x,int fa) {
        siz[x]=1;
        int GS(0);
        for (ri i(first[x]),v;i;i=e[i].nxt) {
            if ((v=e[i].v)==fa||G[v]) continue;
            dfs_find(v,x);
            GS=cmax(siz[v],GS);
            siz[x]+=siz[v];
        }
        GS=cmax(GS,cmp-siz[x]);
        if (GS<cnt) cnt=GS,pos=x;
    }
    void dfs_solve(int x,int fa) {
        siz[x]=1;
        for (ri i(first[x]),v;i;i=e[i].nxt) {
            if ((v=e[i].v)==fa||G[v]) continue;
            dep[v]=dep[x]+1;
            ++b[dep[v]];
            mx=cmax(mx,dep[v]);
            dfs_solve(v,x);
            siz[x]+=siz[v];
        }
    }
    void solve(int x,int S) {
        cmp=cnt=S;
        dfs_find(x,0);
        int np;
        G[np=pos]=true;
        int mxp=0;
        a[0]=1;
        for (ri i(first[np]),v;i;i=e[i].nxt) {
            if (G[v=e[i].v]) continue;
            mx=dep[v]=1;
            b[1]=1;
            dfs_solve(v,np);
            auto calc=[mxp]() {
                int st=1,len=0;
                while(st<=mxp+mx+2) st<<=1,++len;
                int inv=fpow(st,MOD-2);
                w1[1]=fpow(3,(MOD-1)/st),w2[1]=fpow(w1[1],MOD-2);
                for (ri i(2);i<st;pd(i)) w1[i]=w1[i-1]*w1[1]%MOD,w2[i]=w2[i-1]*w2[1]%MOD;
                for (ri i(0);i<st;pd(i)) R[i]=(R[i>>1]>>1)|((i&1)<<(len-1));
                auto NTT1=[st](ull *a) {
                    for (ri i(0);i<st;pd(i)) if (R[i]>i) std::swap(a[R[i]],a[i]);
                    for (ri t(st>>1),d(1);d<st;t>>=1,d<<=1)
                        for (ri i(0);i<st;i+=d<<1)
                            for (ri j(0);j<d;pd(j)) {
                                const ll tmp=w1[t*j]*a[i+j+d]%MOD;
                                a[i+j+d]=a[i+j]-tmp+MOD;
                                a[i+j]=a[i+j]+tmp;
                            }
                        for (ri i(0);i<st;pd(i)) a[i]%=MOD;
                };
                auto NTT2=[st,inv](ull *a) {
                    for (ri i(0);i<st;pd(i)) if (R[i]>i) std::swap(a[R[i]],a[i]);
                    for (ri t(st>>1),d(1);d<st;t>>=1,d<<=1)
                        for (ri i(0);i<st;i+=d<<1)
                            for (ri j(0);j<d;pd(j)) {
                                const ll tmp=w2[t*j]*a[i+j+d]%MOD;
                                a[i+j+d]=a[i+j]-tmp+MOD;
                                a[i+j]=a[i+j]+tmp;
                            } 
                    for (ri i(0);i<st;pd(i)) a[i]=a[i]%MOD*inv%MOD;
                };
                memcpy(c,a,sizeof(ull)*(mxp+1));
                memset(c+mxp+1,0,sizeof(ull)*(st-mxp));
                memcpy(d,b,sizeof(ull)*(mx+1));
                memset(d+mx+1,0,sizeof(ull)*(st-mx));
                NTT1(c),NTT1(d);
                for (ri i(0);i<st;pd(i)) c[i]=c[i]*d[i]%MOD;
                NTT2(c);
                for (ri i(1);i<st;pd(i)) ans[i]=ans[i]+c[i];
            };
            calc();
            mxp=cmax(mxp,mx);
            for (ri j(1);j<=mx;pd(j)) a[j]+=b[j],b[j]=0;
        }
        memset(a,0,sizeof(ll)*(mxp+1));
        for (ri i(first[np]),v;i;i=e[i].nxt) {
            if (G[v=e[i].v]) continue;
            solve(v,siz[v]);
        }
    }
    inline int main() {
        // FI=freopen("nanfeng.in","r",stdin);
        // FO=freopen("nanfeng.out","w",stdout);
        cin >> n >> k;
        for (ri i(1),u,v;i<n;pd(i)) cin >> u >> v,add(u,v);
        w1[0]=w2[0]=1;
        solve(1,n);
        for (ri i(1);i<n;pd(i)) as=MD(as+ans[i]%MOD*fpow(i,k)%MOD);
        printf("%lld\n",as);
        return 0;
    }
}
// Program entry point: runs the solver and forwards its exit status.
int main() {
    const int status=nanfeng::main();
    return status;
}
posted @ 2021-09-15 09:06  ナンカエデ  阅读(55)  评论(0)  编辑  收藏  举报