多项式板子

内涵多项式乘法,多项式求逆,多项式求对数,多项式求exp,多项式求快速幂

常数较大,有空再优化

#include<iostream>
#include<cstdio>
#include<cmath>
using namespace std;
int limit,l,r[10000010];
const int mod=998244353,P=3,invP=(mod+1)/P;
int q_pow(int x,int y){
	if(!y) return 1;
	int z=q_pow(x,y>>1);
	if(y&1) return 1ll*z*z%mod*x%mod;
	else return 1ll*z*z%mod;
} 
void NTT(int *a,int k){
	for(int i=0;i<limit;i++) if(i<r[i]) swap(a[i],a[r[i]]);
	for(int mid=1;mid<limit;mid<<=1){
		int dw=q_pow(k==1?P:invP,(mod-1)/(mid<<1));
		for(int j=0;j<limit;j+=(mid<<1)){
			int w=1;
			for(int k=j;k<j+mid;k++,w=1ll*w*dw%mod){
				int x=a[k],y=1ll*w*a[k+mid]%mod;
				a[k]=(x+y)%mod;
				a[k+mid]=(x-y+mod)%mod;
			}
		} 
	}
}
int* polymul(int n,int m,int *c,int *d){	
	limit=1;
	l=0;
	while(limit<=n+m) limit<<=1,l++;
	int *a=(int* )calloc(2*limit+10,sizeof(int));
	int *b=(int* )calloc(2*limit+10,sizeof(int));
	for(int i=0;i<limit;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
	for(int i=0;i<=n;i++) a[i]=c[i];
	for(int i=0;i<=m;i++) b[i]=d[i];
	NTT(a,1);
	NTT(b,1);
	for(int i=0;i<=limit;i++) a[i]=1ll*a[i]*b[i]%mod;
	NTT(a,-1);
	int inv_limit=q_pow(limit,mod-2);
	for(int i=0;i<=n+m;i++) a[i]=1ll*a[i]*inv_limit%mod;
	return a;
}
int* polymul3(int n,int m,int *c,int *d){	
	limit=1;
	l=0;
	while(limit<=n+n) limit<<=1,l++;
	int *a=(int* )calloc(2*limit+10,sizeof(int));
	int *b=(int* )calloc(2*limit+10,sizeof(int));
	for(int i=0;i<limit;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
	for(int i=0;i<=n;i++) a[i]=c[i];
	for(int i=0;i<=m;i++) b[i]=d[i];
	NTT(a,1);
	NTT(b,1);
	for(int i=0;i<=limit;i++) a[i]=1ll*a[i]*b[i]%mod*b[i]%mod;
	NTT(a,-1);
	int inv_limit=q_pow(limit,mod-2);
	for(int i=0;i<=n+m;i++) a[i]=1ll*a[i]*inv_limit%mod;
	return a;
}
int *polyinv(int n,int *a){
	int *b=(int* )calloc(4*n+10,sizeof(int));	
	b[0]=q_pow(a[0],mod-2);
	for(int i=2;;i*=2){
		limit=1;
		l=0;
		while(limit<=2*i) limit<<=1,l++;
		int *d=(int* )calloc(limit+10,sizeof(int));
		for(int j=0;j<i;j++) d[j]=a[j];
		for(int j=i;j<=limit;j++) d[j]=0;
		int *c=polymul3(i,i,d,b);
		for(int j=0;j<=limit;j++) b[j]=((b[j]+b[j])%mod-c[j]+mod)%mod;
		if(i>=n) break;
	}
	return b;
}
int *polyln(int n,int *a){
	int *b=(int* )calloc(2*n+10,sizeof(int));
	int *e=(int* )calloc(2*n+10,sizeof(int));
	for(int i=0;i<=n;i++) b[i]=a[i];
	int *c=polyinv(n,b);
	for(int i=1;i<=n;i++) e[i-1]=1ll*i*a[i]%mod;
	int *d=polymul(n,n,e,c);
	int *ans=(int* )calloc(2*n+10,sizeof(int));
	for(int i=0;i<=n;i++) ans[i+1]=1ll*q_pow(i+1,mod-2)*d[i]%mod;
	return ans;
} 
int *polyexp(int n,int *a){
	int *b=(int* )calloc(4*n+10,sizeof(int));
	b[0]=1;
	for(int i=2;;i*=2){
		int *c=polyln(i,b);
		for(int j=0;j<=i;j++) c[j]=(mod-c[j]+a[j])%mod;
		c[0]++;
		b=polymul(i,i,b,c);
		if(i>=n) break;
	}
	return b;
}
int *poly_q_pow(int n,int m,int *a,int *k){
	if(a[0]==0){
		int kk=0;
		for(int i=6;i>=0;i--) kk=kk*10+k[i];
		if(kk>n || m>6){
			int *b=(int* )calloc(4*n+10,sizeof(int));
			return b;
		}
	}
	int modk=0,modk2=0,x=0;
	for(int i=m-1;i>=0;i--) modk=(1ll*modk*10+k[i])%mod,modk2=(1ll*modk2*10+k[i])%(mod-1);
	while(a[x]==0 && x<n) x++;
	int *b=(int* )calloc(4*n+10,sizeof(int));
	if(x==n) return b;;
	for(int i=x;i<n;i++) b[i-x]=a[i];
	int y=b[0];
	int z=q_pow(y,mod-2);
	for(int i=0;i<n;i++) b[i]=1ll*b[i]*z%mod;
	
	b=polyln(n,b);
	for(int i=0;i<n;i++) b[i]=1ll*b[i]*modk%mod;
	b=polyexp(n,b);
	y=q_pow(y,modk2);
	for(int i=0;i<n;i++) b[i]=1ll*b[i]*y%mod;
	if(a[0]==0){
		int kk=0;
		for(int i=6;i>=0;i--) kk=kk*10+k[i];
		for(int i=n-1;i>=0;i--) b[min(1ll*n,i+1ll*kk*x)]=b[i],b[i]=0;
	}
	return b;
}
int main(){
	ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);

	return 0;
}
posted @ 2023-12-22 18:10  蒻蒟cdx  阅读(20)  评论(1编辑  收藏  举报