多项式小全家桶

比较安全的模板,传入的数组 \(g\) 有初值也没有问题,传入的 \(g\)\(f\) 可以一样。且求解过程中不会对传入的 \(f\) 修改

没事就清空,清空就没事!

#include<bits/stdc++.h>
using namespace std;
const int mod=998244353;
const int N=4e5+5;
const int inv3=(mod+1)/3;

template<typename A>int mul(A x){return x;}
template<typename A,typename...B>int mul(A x,B...args){return 1ll*x*mul(args...)%mod;}
int add(int a,int b){return a>=mod-b?a-mod+b:a+b;}
int del(int a,int b){return a>=b?a-b:a+mod-b;}
int ksm(int a,int b){
	int res=1;
	while(b) {
		if(b&1) res=mul(res,a);
		a=mul(a,a),b>>=1;
	}
	return res;
} 

int rev[N],iv[N];

void getrev(int len) {
	for(int i=0;i<len;++i)
		rev[i]=(rev[i>>1]>>1)|((i&1)?len/2:0);
}
void ntt(int *f,int op,int len) {
	for(int i=0;i<len;++i)
		if(i<rev[i]) swap(f[i],f[rev[i]]);
	for(int i=2;i<=len;i<<=1) {
		int base=ksm(op==1?3:inv3,(mod-1)/i);
		for(int j=0,p=i>>1;j<len;j+=i)
			for(int k=0,pw=1;k<p;++k,pw=mul(pw,base)) {
				int x=f[j+k],y=mul(f[j+k+p],pw);
				f[j+k]=add(x,y),f[j+k+p]=del(x,y);
			}
	}
	if(op==-1)
		for(int i=0,inv=ksm(len,mod-2);i<len;++i)
			f[i]=mul(f[i],inv);
}

int F[N],G[N];
void mul(int *f,int *g,int *h,int n) {
	int len=1,bit=0;
	while(len<(n<<1)) len<<=1,++bit;
	getrev(len);
	
	memcpy(F,f,sizeof(int)*n);
	memcpy(G,g,sizeof(int)*n);
	ntt(F,1,len),ntt(G,1,len);
	for(int i=0;i<len;++i) F[i]=mul(F[i],G[i]);
	ntt(F,-1,len);
	
	for(int i=0;i<n;++i) h[i]=F[i],F[i]=G[i]=0;
	for(int i=n;i<len;++i) h[i]=F[i]=G[i]=0;
}
int Ft[N],Gt[N];
void mult(int *f,int *g,int *h,int n,int m) {
	memcpy(Ft,f,sizeof(int)*n);
	memcpy(Gt,g,sizeof(int)*m);
	reverse(Gt,Gt+m);
	mul(Ft,Gt,Gt,n+m);
	for(int i=0;i<n;++i) h[i]=Gt[i+m-1];
	memset(Ft,0,sizeof(int)*n);
	memset(Gt,0,sizeof(int)*(n+m));
} 

int fi[N],gi[N];
void inv(int *f,int *g,int n) {
	int cur=1;
	gi[0]=ksm(f[0],mod-2);
	for(;cur<n;cur<<=1) {
		int len=cur<<2;
		getrev(len);
		
		for(int i=0;i<(cur<<1);++i)
			fi[i]=f[i];
		
		ntt(fi,1,len),ntt(gi,1,len);
		for(int i=0;i<len;++i)
			gi[i]=mul(gi[i],(2-mul(gi[i],fi[i])+mod)%mod);
		ntt(gi,-1,len);
		
		for(int i=cur<<1;i<len;++i) gi[i]=fi[i]=0;
	} 
	
	memset(g,0,sizeof(g));
	for(int i=0;i<n;++i) g[i]=gi[i],gi[i]=fi[i]=0;
	for(int i=n;i<cur;++i) gi[i]=fi[i]=0;
}

int fr[N],gr[N];
void div(int *f,int *g,int *q,int *r,int n,int m) {
	for(int i=0;i<n;++i) fr[i]=f[n-1-i];
	for(int i=0;i<m;++i) gr[i]=g[m-1-i];
	
	inv(gr,gr,n-m+1);
	mul(fr,gr,q,n-m+1);
	reverse(q,q+n-m+1);
	mul(g,q,r,m-1);
	
	for(int i=0;i<m-1;++i) r[i]=del(f[i],r[i]);
}

void deriv(int *f,int *g,int n) {
	for(int i=0;i<n;++i) g[i]=mul(f[i+1],i+1); 
	g[n-1]=0;
}
void integ(int *f,int *g,int n) {
	for(int i=n-2;i>=0;--i) g[i+1]=mul(f[i],iv[i+1]);
	g[0]=0;
}

int fl[N],gl[N];
void ln(int *f,int *g,int n) {
	deriv(f,fl,n);
	inv(f,gl,n);
	mul(fl,gl,g,n);
	integ(g,g,n);
	for(int i=0;i<n;++i) fl[i]=gl[i]=0;
}
int fe[N],ge[N];
void exp(int *f,int *g,int n) {
	int cur=1;
	ge[0]=1;
	for(;cur<n;cur<<=1) {
		ln(ge,fe,cur<<1);
		
		for(int i=0;i<(cur<<1);++i)
			fe[i]=del(f[i],fe[i]);
		fe[0]++;
	
		mul(ge,fe,ge,cur<<1);
	}
	memset(g,0,sizeof(g));
	for(int i=0;i<n;++i) g[i]=ge[i],ge[i]=fe[i]=0;
	for(int i=n;i<cur;++i) ge[i]=fe[i]=0; 
}
int fk[N];
void ksm(int *f,int *g,int n,int k) {
	ln(f,fk,n);
	for(int i=0;i<n;++i)
		fk[i]=mul(fk[i],k);
	exp(fk,g,n);
	for(int i=0;i<n;++i) fk[i]=0;
}

#define ls rt<<1
#define rs rt<<1|1
#define mid ((l+r)>>1)
int pool[N*40],*ptr=pool,*Q[N];
void calcq(int l,int r,int rt,int *a) {
	int n=r-l+2;
	int len=1<<(int)ceil(log2(n)+0.1);
	Q[rt]=ptr,ptr+=len*2+2; 
	if(l==r) return Q[rt][0]=1,Q[rt][1]=a[l]?mod-a[l]:0,void();
	calcq(l,mid,ls,a),calcq(mid+1,r,rs,a);
	mul(Q[ls],Q[rs],Q[rt],n);
}
int h[20][N];
void solvemult(int l,int r,int rt,int *f,int *g,int dep) {
	int n=r-l+2;
	if(l==r) return g[l]=f[0],void();

	mult(f,Q[rs],h[dep],n,r-mid+1);
	for(int i=mid-l+2;i<=n;++i) h[dep][i]=0;
	solvemult(l,mid,ls,h[dep],g,dep+1);
	memset(h[dep],0,sizeof(int)*(mid-l+2));
	
	mult(f,Q[ls],h[dep],n,mid-l+2);
	for(int i=r-mid+1;i<=n;++i) h[dep][i]=0;
	solvemult(mid+1,r,rs,h[dep],g,dep+1);
	memset(h[dep],0,sizeof(int)*(r-mid+1));
}
int fm[N];
void calcmultpoints(int *f,int *a,int *b,int n) {
	memcpy(fm,f,sizeof(int)*n);
	calcq(0,n-1,1,a);
	inv(Q[1],Q[1],n+1);
	mult(fm,Q[1],fm,n+1,n+1);
	solvemult(0,n-1,1,fm,b,0);
	memset(fm,0,sizeof(fm));
}

int f[N];
int n;

signed main() {
	ios::sync_with_stdio(false);
	cin.tie(0),cout.tie(0);
	iv[1]=1;
	for(int i=2;i<N;++i) iv[i]=mul(mod-mod/i,iv[mod%i]);


	
	return 0;
}

posted @ 2023-08-28 16:51  _Famiglistimo  阅读(26)  评论(0编辑  收藏  举报