【题解】CodeChef - TREDEG (prufer+生成函数+多项式exp)

【题解】CodeChef - TREDEG (prufer+生成函数+多项式exp)

好毒瘤的数据范围...

先转prufer,现在问题就变成了我要生成一个\(n-2\)长度的序列,每一种序列的权值定义为每种数的\(\prod\)(每种数出现个数+1),可以直接使用指数型生成函数生成,具体的:

\[(\sum_{i=0} {(i+1)^k\over i!}x^i)^n[x^{n-2}](n-2)! \]

这个就生成这个序列的答案了。用exp搞个快速幂就完事了。

最终答案的式子

\[(\sum_{i=0} {(i+1)^k\over i!}x^i)^n[x^{n-2}](n-2)!\over n^{n-2} \]

然后数据范围要我们单独做\(k=1\),那么把\((\sum_{i=0} {i+1\over i!}x^i)^n[x^{n-2}]\)单独拿出来

\[(\sum_{i=0} {i+1\over i!}x^i)^n[x^{n-2}] \]

\(e^x\)代替

\[(xe^x+e^x)^n[x^{n-2}] \]

二项式定理展开

\[[x^{n-2}]\sum {n\choose i} x^ie^{ix}e^{(n-i)x} \]

合并

\[[x^{n-2}]\sum {n\choose i} x^ie^{nx} \]

化简一下

\[\sum {n\choose i} e^{nx}[x^{n-2-i}] \]

泰勒展开一下

\[\sum {n\choose i} {n^{n-2-i}\over (n-2-i)!} \]

就可以\(O(n)\)算了。

代码:(很长)

//@winlere
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
//#define getchar() (__c==__ed?(__ed=__buf+fread(__c=__buf,1,1<<18,stdin),*__c++):*__c++)

using namespace std;  typedef long long ll;   char __buf[1<<18],*__c=__buf,*__ed=__buf;
inline int qr(){
	int ret=0,f=0,c=getchar();
	while(!isdigit(c)) f|=c==45,c=getchar();
	while( isdigit(c)) ret=ret*10+c-48,c=getchar();
	return f?-ret:ret;
}

const int maxn=1<<22;
const int mod=998244353;
const int g=3;
const int gi=(mod+1)/3;
inline int MOD(const int&x){return x>=mod?x-mod:x;}
inline int MOD(const int&x,const int&y){return 1ll*x*y%mod;}
int invs[maxn],jc[maxn],inv[maxn];
int ksm(const int&ba,const int&p){
	int ret=1;
	for(int t=p,b=ba;t;t>>=1,b=MOD(b,b))
		if(t&1) ret=MOD(ret,b);
	return ret;
}
void NTT(int*a,const int&len,const int&tag){
	static int r[maxn];
	for(int t=1;t<len;++t)
		if((r[t]=r[t>>1]>>1|(t&1?len>>1:0))>t)
			swap(a[t],a[r[t]]);
	for(int t=1,s=tag==1?g:gi,wn;t<len;t<<=1){
		wn=ksm(s,(mod-1)/(t<<1));
		for(int i=0;i<len;i+=t<<1)
			for(int j=0,w=1,p;j<t;++j,w=MOD(w,wn))
				p=MOD(a[i+j+t],w),a[i+j+t]=MOD(a[i+j]-p+mod),a[i+j]=MOD(a[i+j]+p);
	}
	if(tag!=1)
		for(int t=0,i=mod-(mod-1)/len;t<len;++t)
			a[t]=MOD(a[t],i);
}
void Deri(int*a,const int&len){
	for(int t=0;t<len-1;++t) a[t]=MOD(a[t+1],t+1);
	a[len-1]=0;	
}
void Inter(int*a,int*b,const int&len){
	for(int t=len-1;t;--t) b[t]=MOD(a[t-1],invs[t]);
	b[0]=0;
}
void INV(int*a,int*b,const int&len){
	if(len==1) return b[0]=ksm(a[0],mod-2),void();
	INV(a,b,len>>1);
	static int A[maxn],B[maxn];
	memset(A,0,len<<3); memset(B,0,len<<3);
	memcpy(A,a,len<<2); memcpy(B,b,len<<2);
	NTT(A,len<<1,1); NTT(B,len<<1,1);
	for(int t=0;t<len<<1;++t) A[t]=MOD(B[t],MOD(A[t],B[t]));
	NTT(A,len<<1,0);
	for(int t=0;t<len;++t) b[t]=MOD(MOD(b[t]+b[t])-A[t]+mod);
}

void LN(int*a,int*b,const int&len){
	static int A[maxn],B[maxn];
	memset(A,0,len<<3); memset(B,0,len<<3);
	memcpy(A,a,len<<2);
	INV(A,B,len); Deri(A,len);
	NTT(A,len<<1,1); NTT(B,len<<1,1);
	for(int t=0;t<len<<1;++t) A[t]=MOD(A[t],B[t]);
	NTT(A,len<<1,0);
	Inter(A,b,len);
}
  
void EXP(int*a,int*b,const int&len){
	if(len==1){b[0]=1;return;}
	EXP(a,b,len>>1);
	static int A[maxn],B[maxn];
	memset(A,0,len<<3); memset(B,0,len<<3);
	memcpy(A,b,len<<1); LN(b,B,len);
	for(int t=0;t<len;++t) B[t]=MOD(a[t]-B[t]+mod);
	++B[0];
	NTT(A,len<<1,1); NTT(B,len<<1,1);
	for(int t=0;t<len<<1;++t) A[t]=MOD(A[t],B[t]);
	NTT(A,len<<1,0);
	for(int t=0;t<len;++t) b[t]=A[t];
}

void POW(int*a,int*b,const int&len,const int&k){
	static int A[maxn],B[maxn];
	memset(A,0,len<<3); memcpy(A,a,len<<2);
	memset(B,0,len<<3);
	LN(A,A,len);
	for(int t=0;t<len;++t) A[t]=MOD(A[t],k);
	EXP(A,B,len);
	memcpy(b,B,len<<2);
}

void pre(const int&n){
	jc[0]=invs[1]=inv[0]=1;
	for(int t=1;t<=n;++t) jc[t]=MOD(jc[t-1],t);
	for(int t=2;t<=n;++t) invs[t]=MOD(mod-mod/t,invs[mod%t]);
	for(int t=1;t<=n;++t) inv[t]=MOD(inv[t-1],invs[t]);
	//for(int t=0;t<=n;++t) if(MOD(jc[t],inv[t])!=1) puts("wa"),cerr<<"t="<<t<<endl;
}

int c(const int&n,const int&m){
	if(n<m) return 0;
	return MOD(jc[n],MOD(inv[m],inv[n-m]));
}

int main(){
	pre(maxn-1);
#ifdef debug
	static int test[maxn],tesv[maxn];
	for(int t=0;t<4;++t) test[t]=t+1;
	NTT(test,8,1);
	NTT(test,8,0);
	for(int t=0;t<16;++t) fprintf(stderr,"%d%c",test[t]," \n"[t==15]);
	INV(test,tesv,4);
	NTT(test,8,1); NTT(tesv,8,1);
	for(int t=0;t<8;++t) test[t]=MOD(test[t],tesv[t]);
	NTT(test,8,0);
	for(int t=0;t<8;++t) fprintf(stderr,"%d%c",test[t]," \n"[t==7]);
	for(int t=0;t<8;++t) test[t]=0;
	for(int t=0;t<4;++t) test[t]=t+1;
	Deri(test,4);
	for(int t=0;t<8;++t) fprintf(stderr,"%d%c",test[t]," \n"[t==7]);
	Inter(test,test,4);
	Deri(test,4);
	for(int t=0;t<8;++t) fprintf(stderr,"%d%c",test[t]," \n"[t==7]);
	for(int t=0;t<8;++t) test[t]=0;
	for(int t=0;t<4;++t) test[t]=t+1;
	POW(test,test,8,2);
	for(int t=0;t<8;++t) fprintf(stderr,"%d%c",test[t]," \n"[t==7]);
#endif
	int T=qr();
	while(T--){
		int n=qr(),k=qr();
		if(k!=1){
			static int a[maxn];
			int len=1;
			while(len<=n) len<<=1;
			memset(a,0,len<<2);
			for(int t=0;t<=n;++t) a[t]=MOD(inv[t],ksm(t+1,k));
			POW(a,a,len,n);
			//cerr<<"a[n-2]="<<a[n-2]<<endl;
			int ans=MOD(MOD(a[n-2],jc[n-2]),ksm(ksm(n,n-2),mod-2));
			printf("%d\n",ans);
		}else{
			int ans=0;
			for(int t=0;t<=n-2;++t)
				ans=MOD(ans+MOD(c(n,t),MOD(ksm(n,n-2-t),inv[n-2-t])));
			ans=MOD(ans,MOD(jc[n-2],ksm(ksm(n,n-2),mod-2)));
			printf("%d\n",ans);			
		}
	}
	return 0;
}
posted @ 2020-01-30 21:45  谁是鸽王  阅读(312)  评论(0编辑  收藏  举报