Luogu P5296 [北京省选集训2019]生成树计数

Luogu P5296 [北京省选集训2019]生成树计数

题目链接

题目大意:给定每条边的边权。一颗生成树的权值为边权和的\(k\)次方。求出所有生成树的权值和。

我们列出答案的式子:

\(E\)为我们枚举的生成树的边集。

\[Ans=\sum_{E}(\sum_{i\in E}w_i)^k\\ =\sum_E \prod_{i\in E} \binom{k}{a_i}w_i^{a_i}[\sum_{i\in E}a_i=k]\\ =\sum_E \frac{1}{k!} \prod_{i\in E} \frac{1}{a_i!} w_i^{a_i}[\sum_{i\in E}a_i=k] \]

我们知道,基尔霍夫矩阵求出来的东西是:

\[\sum_{E}\prod_{i\in E}w_i \]

但是对于上面那个式子,我们发现每条边其实是个多项式:

\[w(x)=\sum_{i=0}^k\frac{1}{i!}w^i \]

进一步发现,最终答案的多项式的项数是\(n*k\)(大概吧)。

于是我们带入大于\(n*k+1\)个值进去,用矩阵树定理算出对应的值,然后拉格朗日插值暴力算出第\(k\)项的系数(应该有更好的方法)。

复杂度:\(O(n^4k)\)

代码:

#include<bits/stdc++.h>
#define ll long long
#define N 35

using namespace std;
inline int Get() {int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}

const ll mod=998244353;
ll ksm(ll t,ll x) {
    ll ans=1;
    for(;x;x>>=1,t=t*t%mod)
        if(x&1) ans=ans*t%mod;
    return ans;
}

int n,m,k;
ll w[N][N];
ll a[N][N];
ll val[N][N];
ll g[N][N][N];
ll f[N*N];
ll fac[N*N],ifac[N*N];
ll Gauss(ll a[N][N],int n) {
    ll ans=1;
    for(int i=2;i<=n;i++) {
        for(int j=i;j<=n;j++) {
            if(a[j][i]) {
                if(i!=j) {
                    ans=ans*(mod-1)%mod;
                    swap(a[i],a[j]);
                }
                break;
            }
        }
        ans=ans*a[i][i]%mod;
        ll inv=ksm(a[i][i],mod-2);
        for(int j=i+1;j<=n;j++) {
            ll tem=inv*a[j][i]%mod;
            for(int k=i;k<=n;k++) a[j][k]=(a[j][k]-tem*a[i][k]%mod+mod)%mod;
        }
    }
    return ans;
}

ll dp[N*N];

void Insert(int v) {
    for(int i=m;i>=0;i--) {
        dp[i]=dp[i]*(mod-v)%mod;
        if(i) (dp[i]+=dp[i-1])%=mod;
    }
}

void Del(int v) {
    for(int i=0;i<=m;i++) {
        if(i) (dp[i]=dp[i]-dp[i-1]+mod);
        dp[i]=dp[i]*ksm(mod-v,mod-2)%mod;
    }
}

int main() {
    n=Get(),k=Get();
    m=n*k+3;
    fac[0]=1;
    for(int i=1;i<=m;i++) fac[i]=fac[i-1]*i%mod;
    ifac[m]=ksm(fac[m],mod-2);
    for(int i=m-1;i>=0;i--) ifac[i]=ifac[i+1]*(i+1)%mod;
    
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
            w[i][j]=Get();
    for(int i=1;i<=n;i++) {
        for(int j=i+1;j<=n;j++) {
            for(int q=0;q<=k;q++) {
                g[i][j][q]=ksm(w[i][j],q)*ifac[q]%mod;
            }
        }
    }
    for(int x0=1;x0<=m;x0++) {
        for(int i=1;i<=n;i++) {
            for(int j=i+1;j<=n;j++) {
                val[i][j]=0;
                ll now=1;
                for(int q=0;q<=k;q++) {
                    (val[i][j]+=g[i][j][q]*now)%=mod;
                    now=now*x0%mod;
                }
            }
        }
        memset(a,0,sizeof(a));
        for(int i=1;i<=n;i++) {
            for(int j=i+1;j<=n;j++) {
                a[i][i]+=val[i][j];
                a[j][j]+=val[i][j];
                a[i][j]-=val[i][j];
                a[j][i]-=val[i][j];
            }
        }
        for(int i=1;i<=n;i++)
            for(int j=1;j<=n;j++)
                a[i][j]=(a[i][j]%mod+mod)%mod;
        f[x0]=Gauss(a,n);
    }
    dp[0]=1;
    for(int i=1;i<=m;i++) Insert(i);
    ll ans=0;
    for(int i=1;i<=m;i++) {
        Del(i);
        ll now=1;
        for(int j=1;j<=m;j++)
            if(i!=j) now=now*(i-j)%mod;
        now=ksm(now%mod+mod,mod-2);
        (ans+=now*dp[k]%mod*f[i])%=mod;
        Insert(i);
    }
    cout<<ans*fac[k]%mod<<"\n";
    return 0;
}

posted @ 2019-04-17 14:45  hec0411  阅读(378)  评论(0编辑  收藏  举报