【题解】CodeChef - TREDEG (prufer+生成函数+多项式exp)
【题解】CodeChef - TREDEG (prufer+生成函数+多项式exp)
好毒瘤的数据范围...
先转prufer,现在问题就变成了我要生成一个\(n-2\)长度的序列,每一种序列的权值定义为每种数的\(\prod\)(每种数出现个数+1),可以直接使用指数型生成函数生成,具体的:
\[(\sum_{i=0} {(i+1)^k\over i!}x^i)^n[x^{n-2}](n-2)!
\]
这个就生成这个序列的答案了。用exp搞个快速幂就完事了。
最终答案的式子
\[(\sum_{i=0} {(i+1)^k\over i!}x^i)^n[x^{n-2}](n-2)!\over n^{n-2}
\]
然后数据范围要我们单独做\(k=1\),那么把\((\sum_{i=0} {i+1\over i!}x^i)^n[x^{n-2}]\)单独拿出来
\[(\sum_{i=0} {i+1\over i!}x^i)^n[x^{n-2}]
\]
用\(e^x\)代替
\[(xe^x+e^x)^n[x^{n-2}]
\]
二项式定理展开
\[[x^{n-2}]\sum {n\choose i} x^ie^{ix}e^{(n-i)x}
\]
合并
\[[x^{n-2}]\sum {n\choose i} x^ie^{nx}
\]
化简一下
\[\sum {n\choose i} e^{nx}[x^{n-2-i}]
\]
泰勒展开一下
\[\sum {n\choose i} {n^{n-2-i}\over (n-2-i)!}
\]
就可以\(O(n)\)算了。
代码:(很长)
//@winlere
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
//#define getchar() (__c==__ed?(__ed=__buf+fread(__c=__buf,1,1<<18,stdin),*__c++):*__c++)
using namespace std; typedef long long ll; char __buf[1<<18],*__c=__buf,*__ed=__buf;
inline int qr(){
int ret=0,f=0,c=getchar();
while(!isdigit(c)) f|=c==45,c=getchar();
while( isdigit(c)) ret=ret*10+c-48,c=getchar();
return f?-ret:ret;
}
const int maxn=1<<22;
const int mod=998244353;
const int g=3;
const int gi=(mod+1)/3;
inline int MOD(const int&x){return x>=mod?x-mod:x;}
inline int MOD(const int&x,const int&y){return 1ll*x*y%mod;}
int invs[maxn],jc[maxn],inv[maxn];
int ksm(const int&ba,const int&p){
int ret=1;
for(int t=p,b=ba;t;t>>=1,b=MOD(b,b))
if(t&1) ret=MOD(ret,b);
return ret;
}
void NTT(int*a,const int&len,const int&tag){
static int r[maxn];
for(int t=1;t<len;++t)
if((r[t]=r[t>>1]>>1|(t&1?len>>1:0))>t)
swap(a[t],a[r[t]]);
for(int t=1,s=tag==1?g:gi,wn;t<len;t<<=1){
wn=ksm(s,(mod-1)/(t<<1));
for(int i=0;i<len;i+=t<<1)
for(int j=0,w=1,p;j<t;++j,w=MOD(w,wn))
p=MOD(a[i+j+t],w),a[i+j+t]=MOD(a[i+j]-p+mod),a[i+j]=MOD(a[i+j]+p);
}
if(tag!=1)
for(int t=0,i=mod-(mod-1)/len;t<len;++t)
a[t]=MOD(a[t],i);
}
void Deri(int*a,const int&len){
for(int t=0;t<len-1;++t) a[t]=MOD(a[t+1],t+1);
a[len-1]=0;
}
void Inter(int*a,int*b,const int&len){
for(int t=len-1;t;--t) b[t]=MOD(a[t-1],invs[t]);
b[0]=0;
}
void INV(int*a,int*b,const int&len){
if(len==1) return b[0]=ksm(a[0],mod-2),void();
INV(a,b,len>>1);
static int A[maxn],B[maxn];
memset(A,0,len<<3); memset(B,0,len<<3);
memcpy(A,a,len<<2); memcpy(B,b,len<<2);
NTT(A,len<<1,1); NTT(B,len<<1,1);
for(int t=0;t<len<<1;++t) A[t]=MOD(B[t],MOD(A[t],B[t]));
NTT(A,len<<1,0);
for(int t=0;t<len;++t) b[t]=MOD(MOD(b[t]+b[t])-A[t]+mod);
}
void LN(int*a,int*b,const int&len){
static int A[maxn],B[maxn];
memset(A,0,len<<3); memset(B,0,len<<3);
memcpy(A,a,len<<2);
INV(A,B,len); Deri(A,len);
NTT(A,len<<1,1); NTT(B,len<<1,1);
for(int t=0;t<len<<1;++t) A[t]=MOD(A[t],B[t]);
NTT(A,len<<1,0);
Inter(A,b,len);
}
void EXP(int*a,int*b,const int&len){
if(len==1){b[0]=1;return;}
EXP(a,b,len>>1);
static int A[maxn],B[maxn];
memset(A,0,len<<3); memset(B,0,len<<3);
memcpy(A,b,len<<1); LN(b,B,len);
for(int t=0;t<len;++t) B[t]=MOD(a[t]-B[t]+mod);
++B[0];
NTT(A,len<<1,1); NTT(B,len<<1,1);
for(int t=0;t<len<<1;++t) A[t]=MOD(A[t],B[t]);
NTT(A,len<<1,0);
for(int t=0;t<len;++t) b[t]=A[t];
}
void POW(int*a,int*b,const int&len,const int&k){
static int A[maxn],B[maxn];
memset(A,0,len<<3); memcpy(A,a,len<<2);
memset(B,0,len<<3);
LN(A,A,len);
for(int t=0;t<len;++t) A[t]=MOD(A[t],k);
EXP(A,B,len);
memcpy(b,B,len<<2);
}
void pre(const int&n){
jc[0]=invs[1]=inv[0]=1;
for(int t=1;t<=n;++t) jc[t]=MOD(jc[t-1],t);
for(int t=2;t<=n;++t) invs[t]=MOD(mod-mod/t,invs[mod%t]);
for(int t=1;t<=n;++t) inv[t]=MOD(inv[t-1],invs[t]);
//for(int t=0;t<=n;++t) if(MOD(jc[t],inv[t])!=1) puts("wa"),cerr<<"t="<<t<<endl;
}
int c(const int&n,const int&m){
if(n<m) return 0;
return MOD(jc[n],MOD(inv[m],inv[n-m]));
}
int main(){
pre(maxn-1);
#ifdef debug
static int test[maxn],tesv[maxn];
for(int t=0;t<4;++t) test[t]=t+1;
NTT(test,8,1);
NTT(test,8,0);
for(int t=0;t<16;++t) fprintf(stderr,"%d%c",test[t]," \n"[t==15]);
INV(test,tesv,4);
NTT(test,8,1); NTT(tesv,8,1);
for(int t=0;t<8;++t) test[t]=MOD(test[t],tesv[t]);
NTT(test,8,0);
for(int t=0;t<8;++t) fprintf(stderr,"%d%c",test[t]," \n"[t==7]);
for(int t=0;t<8;++t) test[t]=0;
for(int t=0;t<4;++t) test[t]=t+1;
Deri(test,4);
for(int t=0;t<8;++t) fprintf(stderr,"%d%c",test[t]," \n"[t==7]);
Inter(test,test,4);
Deri(test,4);
for(int t=0;t<8;++t) fprintf(stderr,"%d%c",test[t]," \n"[t==7]);
for(int t=0;t<8;++t) test[t]=0;
for(int t=0;t<4;++t) test[t]=t+1;
POW(test,test,8,2);
for(int t=0;t<8;++t) fprintf(stderr,"%d%c",test[t]," \n"[t==7]);
#endif
int T=qr();
while(T--){
int n=qr(),k=qr();
if(k!=1){
static int a[maxn];
int len=1;
while(len<=n) len<<=1;
memset(a,0,len<<2);
for(int t=0;t<=n;++t) a[t]=MOD(inv[t],ksm(t+1,k));
POW(a,a,len,n);
//cerr<<"a[n-2]="<<a[n-2]<<endl;
int ans=MOD(MOD(a[n-2],jc[n-2]),ksm(ksm(n,n-2),mod-2));
printf("%d\n",ans);
}else{
int ans=0;
for(int t=0;t<=n-2;++t)
ans=MOD(ans+MOD(c(n,t),MOD(ksm(n,n-2-t),inv[n-2-t])));
ans=MOD(ans,MOD(jc[n-2],ksm(ksm(n,n-2),mod-2)));
printf("%d\n",ans);
}
}
return 0;
}
博客保留所有权利,谢绝学步园、码迷等不在文首明显处显著标明转载来源的任何个人或组织进行转载!其他文明转载授权且欢迎!