LOJ6386题解

对于 \(\sum_{i=0}^{n}f(i)\) 的这种问题但是 \(f(i)\) 不是多项式函数且 \(n\) 很大时可以考虑一个用矩阵做的 DP:

\[\begin{bmatrix}\binom{n}{m}\\\sum_{i=0}^{m}\binom{n}{i}\end{bmatrix}=\begin{bmatrix}\frac{n-m+1}{m}&0\\\frac{n-m+1}{m}&1\end{bmatrix}\begin{bmatrix}\binom{n}{m-1}\\\sum_{i=0}^{m-1}\binom{n}{i}\end{bmatrix} \]

考虑这个矩阵的乘法:

\[\begin{bmatrix}a_1&0\\a_2&1\end{bmatrix}\begin{bmatrix}b_1&0\\b_2&1\end{bmatrix}=\begin{bmatrix}a_1b_1&0\\a_2b_1+b_2&1\end{bmatrix} \]

我们设这个矩阵是:

\[\begin{bmatrix}\frac{G(x)}{F(x)}&0\\\frac{H(x)}{F(x)}&1\end{bmatrix} \]

那么上述矩阵运算可以被改写为:

\[F_3(x)=F_1(x)F_2(x) \]

\[G_3(x)=G_1(x)G_2(x) \]

\[H_3(x)=G_1(x)H_2(x)+H_1(x)F_2(x) \]

为什么和上面是反的是因为这个矩阵乘法是左乘。

使用类似 快速阶乘算法 和 快速调和级数求和 那样的倍增就好了。。。复杂度 \(O(\sqrt{n}\log n)\),常数可能有亿点大。

#include<cstdio>
#include<cmath>
#define IMP(lim,act) for(int qwq=(lim),i=0;i^qwq;++i)act
const int M=1<<17|5,mod=998244353;
int ifac[M<<2],buf[M<<2],*w[20];
inline int Getlen(const int&n){
	int len(0);while((1<<len)<n)++len;return len;
}
inline int Add(const int&a,const int&b){
	return a+b>=mod?a+b-mod:a+b;
}
inline int Del(const int&a,const int&b){
	return b>a?a-b+mod:a-b;
}
inline void swap(int&a,int&b){
	int c=a;a=b;b=c;
}
inline int pow(int a,int b=mod-2){
	int ans(1);for(;b;b>>=1,a=1ll*a*a%mod)if(b&1)ans=1ll*ans*a%mod;return ans;
}
inline void init(const int&n){
	const int&m=Getlen(n);int*now=buf;w[m]=now;now+=1<<m;
	w[m][0]=1;w[m][1]=pow(3,mod-1>>m+1);for(int i=2;i^1<<m;++i)w[m][i]=1ll*w[m][i-1]*w[m][1]%mod;
	for(int k=m-1;k>=0&&(w[k]=now,now+=1<<k);--k)IMP(1<<k,w[k][i]=w[k+1][i<<1]);
	ifac[0]=ifac[1]=1;for(int i=2;i<n;++i)ifac[i]=1ll*(mod-mod/i)*ifac[mod%i]%mod;
	for(int i=1;i<n;++i)ifac[i]=1ll*ifac[i-1]*ifac[i]%mod;
}
inline void DFT(int*f,const int&M){
	const int&n=1<<M;
	for(int len=n>>1,d=M-1;d>=0;--d,len>>=1)for(int k=0;k^n;k+=len<<1){
		int*W=w[d],*L=f+(k),*R=f+(k|len),x,y;IMP(len,(x=*L,y=*R)),*L++=Add(x,y),*R++=1ll**W++*Del(x,y)%mod;
	}
}
inline void IDFT(int*f,const int&M){
	const int&n=1<<M;
	for(int len=1,d=0;d^M;++d,len<<=1)for(int k=0;k^n;k+=len<<1){
		int*W=w[d],*L=f+(k),*R=f+(k|len),x,y;IMP(len,(x=*L,y=1ll**W++**R%mod)),*L++=Add(x,y),*R++=Del(x,y);
	}
	const int&k=pow(n);IMP(n,f[i]=1ll*f[i]*k%mod);for(int i=1;(i<<1)<n;++i)swap(f[i],f[n-i]);
}
inline void Getinv(int*f,const int&n){
	static int g[M];g[0]=f[0];for(int i=1;i<n;++i)g[i]=1ll*g[i-1]*f[i]%mod;
	int t,c=pow(g[n-1]);for(int i=n-1;i>=1;--i)t=f[i],f[i]=1ll*g[i-1]*c%mod,c=1ll*c*t%mod;f[0]=t;
}
inline void PT(int*f,int*g,const int&n,const int&m){
	static int F[M],G[M],H[M];H[0]=1;IMP(n,H[0]=1ll*H[0]*(m-i)%mod);IMP(n+n,G[i]=m-n+i);G[0]=1;Getinv(G,n+n);
	for(int i=1;i^n;++i)H[i]=1ll*(m+i)*G[i]%mod*H[i-1]%mod;
	IMP(n,F[i]=1ll*ifac[i]*(n-i-1&1?mod-ifac[n-i-1]:ifac[n-i-1])%mod*f[i]%mod);
	const int&len=Getlen(n+n);DFT(F,len);DFT(G,len);IMP(1<<len,F[i]=1ll*F[i]*G[i]%mod);IDFT(F,len);
	IMP(n,g[i]=1ll*F[n+i]*H[i]%mod);IMP(1<<len,F[i]=G[i]=H[i]=0);
}
inline int GetAns(const int&n,const int&m){
	static int F1[M],G1[M],H1[M],F2[M],G2[M],H2[M];const int&B=sqrt(m),&len=Getlen(B+1)-2;
	F1[0]=1;G1[0]=n;H1[0]=n;F1[1]=B+1;G1[1]=n-(B+1)+1;H1[1]=n-(B+1)+1;
	for(int i=len;i>=0;--i){
		const int&q=B>>i+1,&p=B>>i;
		PT(F1,F2,q+1,q+1);IMP(q,F1[q+i+1]=F2[i]);PT(F1,F2,q*2+1,1ll*q*pow(B)%mod);
		PT(G1,G2,q+1,q+1);IMP(q,G1[q+i+1]=G2[i]);PT(G1,G2,q*2+1,1ll*q*pow(B)%mod);
		PT(H1,H2,q+1,q+1);IMP(q,H1[q+i+1]=H2[i]);PT(H1,H2,q*2+1,1ll*q*pow(B)%mod);
		for(int i=0;i<=q*2;++i){
			H1[i]=(1ll*G1[i]*H2[i]+1ll*H1[i]*F2[i])%mod;
			G1[i]=1ll*G1[i]*G2[i]%mod;F1[i]=1ll*F1[i]*F2[i]%mod;
		}
		if(q*2+1==p){
			for(int i=0;i<=q*2;++i){
				H1[i]=(1ll*G1[i]*(n-(i*B+p)+1)+1ll*H1[i]*(i*B+p))%mod;
				G1[i]=1ll*G1[i]*(n-(i*B+p)+1)%mod;F1[i]=1ll*F1[i]*(i*B+p)%mod;
			}
			H1[p]=0;G1[p]=F1[p]=1;
			for(int i=1;i<=p;++i){
				H1[p]=(1ll*G1[p]*(n-(p*B+i)+1)+1ll*H1[p]*(p*B+i))%mod;
				G1[p]=1ll*G1[p]*(n-(p*B+i)+1)%mod;F1[p]=1ll*F1[p]*(p*B+i)%mod;
			}
		}
	}
	int F(1),G(1),H(1);IMP(B,(H=(1ll*G*H1[i]+1ll*H*F1[i])%mod,G=1ll*G*G1[i]%mod,F=1ll*F*F1[i]%mod));
	for(int i=B*B+1;i<=m;++i)H=(1ll*G*(n-i+1)+1ll*H*i)%mod,G=1ll*G*(n-i+1)%mod,F=1ll*F*i%mod;
	return 1ll*H*pow(F)%mod;
}
signed main(){
	int T,N,M;scanf("%d",&T);init(1<<17);while(T--)scanf("%d%d",&N,&M),printf("%d\n",GetAns(N,M));
}
posted @ 2022-07-01 08:54  Prean  阅读(34)  评论(0编辑  收藏  举报
var canShowAdsense=function(){return !!0};