bzoj 5093 [Lydsy1711月赛]图的价值——第二类斯特林数

题目:https://www.lydsy.com/JudgeOnline/problem.php?id=5093

不要见到组合数就拆!

枚举每个点的度数,则答案为  \( n*\sum\limits_{i=0}^{n-1}C_{n-1}^{i}*2^{C_{n-1}^{2}}*i^{k} \)

(又是那个公式:\( x^{n}=\sum\limits_{k=0}^{n}C_{x}^{k}*(k!)*S(n,k) \))

               \( = n*2^{C_{n-1}^{2}}\sum\limits_{i=0}^{n-1}C_{n-1}^{i}\sum\limits_{j=0}^{k}C_{i}^{j}*(j!)*S(k,j) \)

这里发现组合数的角标有一样的,不要把那两个组合数拆了以消掉阶乘,而可以通过组合意义把它们合起来!

               \( = n*2^{C_{n-1}^{2}}\sum\limits_{j=0}^{k}(j!)*S(k,j)\sum\limits_{i=0}^{n-1}C_{n-1}^{i}*C_{i}^{j} \)

从 n-1 个数里选 i 个数,再从 i 个数里选 j 个数,而且 i 从 0 枚举到 n-1 ,就可以看作从 n-1 个数里选了 j 个数,剩下  n-1-j  个数可选可不选。

  (比如一个点 1 想连到另一个点 2 , 1 先在 n-1 个点里选 i 个点连上,再从这 i 个点里选 j 个点连到点 2 ; 也即点 1 在 n-1 个点里选了 j 个点连向点 2 ,其余的点可能和点 1 相连)

所以             \( = n*2^{C_{n-1}^{2}}\sum\limits_{j=0}^{k}(j!)*S(k,j)*C_{n-1}^{j}*2^{n-1-j} \)

用 NTT 预处理斯特林数就行了。别把现在的组合数拆掉,因为是一维特别大,一维特别小,所以分子和分母消一下。需要预处理阶乘和下降幂,才能做到 O(1) 算组合数。

注意指数上模 mod-1 。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const int N=2e5+5,M=(1<<19)+5,mod=998244353;
int n,m,s[N],a[M],b[M],jcn[N],ljc[N],len,r[M];
void upd(int &x){x>=mod?x-=mod:0;}
int pw(int x,int k)
{int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;}
void ntt(int *a,bool fx)
{
  for(int i=0;i<len;i++)
    if(i<r[i])swap(a[i],a[r[i]]);
  for(int R=2;R<=len;R<<=1)
    {
      int wn=pw( 3,fx?(mod-1)-(mod-1)/R:(mod-1)/R );
      for(int i=0,m=R>>1;i<len;i+=R)
    for(int j=0,w=1;j<m;j++,w=(ll)w*wn%mod)
      {
        int x=a[i+j], y=(ll)w*a[i+m+j]%mod;
        a[i+j]=x+y;  upd(a[i+j]);
        a[i+m+j]=x+mod-y;  upd(a[i+m+j]);
      }
    }
  if(!fx)return ; int inv=pw(len,mod-2);
  for(int i=0;i<len;i++)a[i]=(ll)a[i]*inv%mod;
}
void init()
{
  jcn[0]=1;for(int i=1;i<=m;i++)jcn[i]=(ll)jcn[i-1]*i%mod;
  jcn[m]=pw(jcn[m],mod-2);for(int i=m-1;i>=0;i--)jcn[i]=(ll)jcn[i+1]*(i+1)%mod;
  for(int i=0,j=1;i<=m;i++,j=-j)
    a[i]=j*jcn[i]+mod,upd(a[i]);
  for(int i=0;i<=m;i++)
    b[i]=(ll)pw(i,m)*jcn[i]%mod;
  for(len=1;len<=m<<1;len<<=1);
  for(int i=0;i<len;i++)r[i]=(r[i>>1]>>1)+((i&1)?len>>1:0);
  ntt(a,0); ntt(b,0);
  for(int i=0;i<len;i++)a[i]=(ll)a[i]*b[i]%mod;
  ntt(a,1);
  for(int i=0;i<=m;i++)s[i]=a[i];

  ljc[0]=1;for(int i=n-1,j=1;j<=m;j++,i--)ljc[j]=(ll)ljc[j-1]*i%mod;
}
int C(int m)
{
  return (ll)ljc[m]*jcn[m]%mod;
}
int main()
{
  scanf("%d%d",&n,&m);
  init();
  int ans=0;
  for(int i=0,jc=1;i<=m;i++,jc=(ll)jc*i%mod)
    {
      if(i>n-1)break;//or pw()
      ans=(ans+(ll)jc*C(i)%mod*pw(2,n-1-i)%mod*s[i])%mod;
    }
  ans=(ll)ans*n%mod*pw(2,(ll)(n-1)*(n-2)/2%(mod-1))%mod;//mod-1
  printf("%d\n",ans);
  return 0;
}

 

posted on 2018-12-05 13:04  Narh  阅读(143)  评论(0编辑  收藏  举报

导航