[模板]多项式EX

模板代码

新增快速沃尔什变换与其逆变换。

即代码中 \(DWT\)\(IDWT\) 的部分。

#include<cstdio>
#include<algorithm>
#include<vector>
using namespace std;

#define NDEBUG
#include<cassert>

namespace Elaina{
    #define rep(i, l, r) for(int i=(l), i##_end_=(r); i<=i##_end_; ++i)
    #define drep(i, l, r) for(int i=(l), i##_end_=(r); i>=i##_end_; --i)
    #define fi first
    #define se second
    #define mp(a, b) make_pair(a, b)
    #define Endl putchar('\n')
    #define mmset(a, b) memset(a, b, sizeof a)
    // #define int long long
    typedef long long ll;
    typedef unsigned long long ull;
    typedef unsigned int uint;
    typedef pair<int, int> pii;
    typedef pair<ll, ll> pll;
    template<class T>inline T fab(T x){ return x<0? -x: x; }
    template<class T>inline void getmin(T& x, const T rhs){ x=min(x, rhs); }
    template<class T>inline void getmax(T& x, const T rhs){ x=max(x, rhs); }
    template<class T>inline T readin(T x){
        x=0; int f=0; char c;
        while((c=getchar())<'0' || '9'<c) if(c=='-') f=1;
        for(x=(c^48); '0'<=(c=getchar()) && c<='9'; x=(x<<1)+(x<<3)+(c^48));
        return f? -x: x;
    }
    template<class T>inline void writc(T x, char s='\n'){
        static int fwri_sta[1005], fwri_ed=0;
        if(x<0) putchar('-'), x=-x;
        do fwri_sta[++fwri_ed]=x%10, x/=10; while(x);
        while(putchar(fwri_sta[fwri_ed--]^48), fwri_ed);
        putchar(s);
    }
}
using namespace Elaina;

const int maxn=1<<17;

namespace __poly{
    const int Mod=1004535809, __G=3, inv2=998244354>>1;
    int w[maxn+5], Inv[maxn+5];
    vector<int>ans;
    vector< vector<int> >p;
    template<class T>inline int qkpow(int a, T n){
    //快速幂
        int ans=1;
        for(; n; n>>=1, a=1ll*a*a%Mod)
            if(n&1) ans=1ll*ans*a%Mod;
        return ans;
    }
    inline void Init(){
    //解决初始化, 包括原根的次方以及 maxn 以内的逆元
        for(uint i=1; i<maxn; i<<=1){
            w[i]=1;
            int t=qkpow(__G, (Mod-1)/i/2);
            for(uint j=1; j<i; ++j) w[i+j]=w[i+j-1]*t;
        }
        Inv[0]=Inv[1]=1;
        for(int i=2; i<=maxn; ++i)
            Inv[i]=1ll*Inv[Mod%i]*(Mod-Mod/i)%Mod;
    }
    inline pii Mul(pii x, pii y, int f){
    //计算 a+bω 部分的 pzz 的乘法
        return mp((1ll*x.fi*y.fi%Mod+1ll*x.se*y.se%Mod*f%Mod)%Mod, (1ll*x.fi*y.se%Mod+1ll*x.se*y.fi%Mod)%Mod);
    }
    inline int Quadratic_residue(int a){
    //常数在模意义下的开根
        if(a<=1) return a;
        //使用欧拉定律判断有无解
        if(qkpow(a, (Mod-1)>>1)!=1) return -1;
        int x,f;
        //找到一个 x 使得 x^2 - a 不是二次剩余
        do x=(((ull)rand()<<15)^rand())%(a-1)+1;
        while(qkpow(f=x*x-a, (Mod-1)>>1)==1);
        //初始变量
        pii ans=mp(1, 0), t=mp(x, 1);
        //类似于快速幂
        for(uint i=(Mod+1)>>1; i; i>>=1, t=Mul(t, t, f))
            if(i&1) ans=Mul(ans, t, f);
        //返回较小根
        return min(ans.fi, Mod-ans.fi);
    }
    inline int Get(const int x){int n=1;while(n<=x)n<<=1;return n;}
    inline void ntt(vector<int>&f,const int n){
    //ntt 的标准写法
        static ull F[maxn+5];
        if((int)f.size()!=n) f.resize(n);//避免 RE
        for(int i=0,j=0;i<n;++i){
            //复制数组, 同时更改系数
            F[i]=f[j];
            for(int k=n>>1; (j^=k)<k; k>>=1);
        }
        for(int i=1; i<n; i<<=1) for(int j=0; j<n; j+=i<<1){
            int *W=w+i;
            ull *F0=F+j, *F1=F+j+i, t;
            //通过数组指针进行访问
            for(int k=j; k<j+i; ++k, ++W, ++F0, ++F1){
                t=(*F1)*(*W)%Mod;
				(*F1)=*F0+Mod-t, (*F0)+=t;
            }
        }
        for(int i=0; i<n; ++i) f[i]=F[i]%Mod; //将数组复制回去
    }
    inline void Intt(vector<int>&f,const int n){
    //逆 ntt 的运算, 和一般的写法有点不一样, 稍微研究一下
        f.resize(n), reverse(f.begin()+1, f.end());
        ntt(f, n);
        int n_Inv=qkpow(n, Mod-2);
        //要乘逆元并取模 (取模在 * 里面重载了)
        for(int i=0; i<n; ++i) f[i]=1ll*f[i]*n_Inv%Mod;
    }
    inline vector<int> operator +(const vector<int>&f,const vector<int>&g){
    //重载多项式的加法
        vector<int>ans=f;
        for(int i=0, siz=f.size(); i<siz; ++i)
            ans[i]=(ans[i]+g[i])%Mod;
        return ans;
    }
    inline vector<int> operator *(const vector<int>&f,const vector<int>&g){
        if((ull)f.size()*g.size()<=1000){
            //在这种情况下, 暴力算要更快一些...
            vector<int>ans;
            ans.resize(f.size()+g.size()-1);
            for(int i=0, sf=f.size(); i<sf; ++i)
                for(int j=0, sg=g.size(); j<sg; ++j)
                    ans[i+j]=(ans[i+j]+1ll*f[i]*g[j]%Mod)%Mod;
            return ans;
        }
        //保存数组
        static vector<int> F, G;
        F=f, G=g;
        int p=Get((int)f.size()+(int)g.size()-2);//为什么是 -2 ?
        ntt(F,p), ntt(G,p);
        for(int i=0; i<p; ++i) F[i]=1ll*F[i]*G[i]%Mod;
        Intt(F, p); F.resize((int)f.size()+(int)g.size()-1);
        return F;
    }
    vector<int>& polyinv(const vector<int>& f,int n=-1){
    //返回逆元多项式
        //默认为数组的大小
        if(n==-1) n=f.size();
        if(n==1){
            //如果到了最后一项, 则直接求逆元
            static vector<int> ans;
            return ans.clear(), ans.push_back(qkpow(f[0], Mod-2)), ans;
        }
        //剩下的是标准操作
        vector<int> &ans=polyinv(f,(n+1)>>1);
        vector<int> tmp(&f[0], &f[0]+n);
        int p=Get(n*2-2);
        ntt(tmp, p), ntt(ans, p);
        for(int i=0; i<p; ++i)
            ans[i]=1ll*ans[i]*(2+Mod-1ll*ans[i]*tmp[i]%Mod)%Mod;
        Intt(ans, p); ans.resize(n);
        return ans;
    }
    inline void polydiv(const vector<int>&a, const vector<int>&b, vector<int>&d, vector<int>&r){
    //多项式除法
        if(b.size()>a.size()) return d.clear(), (void)(r=a);
        vector<int>A=a, B=b, iB;
        int n=a.size(), m=b.size();
        reverse(A.begin(), A.end()), reverse(B.begin(), B.end());
        B.resize(n-m+1), iB=polyinv(B);
        d=A*iB;
        d.resize(n-m+1), reverse(d.begin(), d.end());
        r=b*d, r.resize(m-1);
        for(int i=0; i<m-1; ++i) r[i]=(a[i]+Mod-r[i])%Mod;
    }
    inline vector<int> Derivative(const vector<int>& a){
    //函数的导数
        vector<int>ans((int)a.size()-1);
        for(int i=1, sa=a.size(); i<sa; ++i)
            ans[i-1]=1ll*a[i]*i%Mod;
        return ans;
    }
    inline vector<int> Integral(const vector<int>& a){
    //计算微分
        vector<int>ans(a.size()+1);
        for(int i=0, sa=a.size(); i<sa; ++i)
            ans[i+1]=1ll*a[i]*Inv[i+1]%Mod;
        return ans;
    }
    inline vector<int>polyln(const vector<int>& f){
    //自然对数
        vector<int>ans=Derivative(f)*polyinv(f);
        ans.resize((int)f.size()-1);
        return Integral(ans);
    }
    vector<int> polyexp(const vector<int>& f,int n=-1){
        if(n==-1) n=f.size();
        if(n==1) return {1};
        vector<int> ans=polyexp(f, (n+1)>>1), tmp;
        ans.resize(n),tmp=polyln(ans);
        for(auto it=tmp.begin(); it!=tmp.end(); ++it)
            (*it)=-(*it);//此处 C++11 的写法在 luoguOJ 上的评测结果不同...
        // for(Z &i:tmp)i=-i;
        ++tmp[0];
        ans=ans*(tmp+f); ans.resize(n);
        return ans;
    }
    vector<int> polysqrt(const vector<int>& f, int n=-1){
        if(n==-1)n=f.size();
        vector<int>ans;
        if(n==1) return ans.push_back(Quadratic_residue(f[0])), ans;
        ans=polysqrt(f,(n+1)>>1);
        vector<int>tmp(&f[0], &f[0]+n);
        ans.resize(n), ans=ans+tmp*polyinv(ans);
        for(auto it=ans.begin(); it!=ans.end(); ++it)
            *it=((*it)&1)? (((*it)+Mod)%Mod)>>1: (*it)>>1;
        return ans;
    }
    inline vector<int> polyqkpow(const vector<int>& f, const int k){
        vector<int>ans=polyln(f);
        for(auto it=ans.begin(); it!=ans.end(); ++it)
            (*it)=(*it)*k;//此处也有 C11 的问题
        // for(Z &i:ans)i=i*k;
        return polyexp(ans);
    }
    void Evaluate_Interpolate_Init(int l, int r, int t, const vector<int> &a){
		if(l==r)return p[t].clear(), p[t].push_back((Mod-a[l])%Mod), p[t].push_back(1);
		int mid=(l+r)/2, k=t<<1;
		Evaluate_Interpolate_Init(l,mid,k,a),Evaluate_Interpolate_Init(mid+1,r,k|1,a);
		p[t]=p[k]*p[k|1];
	}
	void Evaluate(int l,int r,int t,const vector<int> &f,const vector<int> &a){
		if(r-l+1<=512){
			for(int i=l;i<=r;++i){
				int x=0, a1=a[i], a2=a[i]*a[i], a3=a[i]*a2, a4=a[i]*a3, a5=a[i]*a4, a6=a[i]*a5, a7=a[i]*a6, a8=a[i]*a7;
				int j=f.size();
				while(j>=8) x=(0ll+1ll*x*a8%Mod+1ll*f[j-1]*a7%Mod+1ll*f[j-2]*a6%Mod+1ll*f[j-3]*a5%Mod+1ll*f[j-4]*a4%Mod+1ll*f[j-5]*a3%Mod+1ll*f[j-6]*a2%Mod+1ll*f[j-7]*a1%Mod+1ll*f[j-8]%Mod)%Mod, j-=8;
				while(j--) x=(1ll*x*a[i]%Mod+f[j])%Mod;
				ans.push_back(x);
			}
			return;
		}
		vector<int>tmp;
		polydiv(f, p[t], tmp, tmp);
		Evaluate(l, (l+r)/2, t<<1, tmp, a),Evaluate((l+r)/2+1, r, t<<1|1, tmp, a);
	}
	inline vector<int> Evaluate(const vector<int>&f, const vector<int>&a, int flag=-1){
		if(flag==-1) p.resize(a.size()<<2), Evaluate_Interpolate_Init(0, a.size()-1, 1, a);
		return ans.clear(), Evaluate(0, a.size()-1, 1, f, a), ans;
	}
	vector<int> Interpolate(int l, int r, int t, const vector<int>&x, const vector<int>&f){
		if(l==r) return {f[l]};
		int mid=(l+r)/2, k=t<<1;
		return Interpolate(l, mid, k, x, f)*p[k|1]+Interpolate(mid+1, r, k|1, x, f)*p[k];
	}
	inline vector<int> Interpolate(const vector<int>&x,const vector<int>&y){
		int n=x.size();
		p.resize(n<<2),Evaluate_Interpolate_Init(0, n-1, 1, x);
		vector<int> f=Evaluate(Derivative(p[1]), x, 0);
		for(int i=0; i<n; ++i) f[i]=y[i]*qkpow(f[i], Mod-2);
		return Interpolate(0, n-1, 1, x, f);
	}
    inline void DWT_or(vector<int>&f,const int opt){
        //if opt==-1, DWT_or will be IDWT_or
        int N=Get(f.size()); f.resize(N);
        for(int t=2; t<=N; t<<=1)
            for(int i=0, p=t>>1; i<N; i+=t) for(int j=i; j<i+p; ++j)
                f[j+p]=(0ll+f[j+p]+Mod+f[j]*opt)%Mod;
    }
    inline void DWT_and(vector<int>&f,const int opt){
        int N=Get(f.size()); f.resize(N);
        for(int t=2; t<=N; t<<=1)
            for(int i=0, p=t>>1; i<N; i+=t) for(int j=i; j<i+p; ++j)
                f[j]=(0ll+f[j]+Mod+f[j+p]*opt)%Mod;
    }
    inline void DWT_xor(vector<int>&f, const int opt){
        int N=Get(f.size()-1); //pay attention to this -1
        f.resize(N); int x, y;
        for(int t=2; t<=N; t<<=1)
            for(int i=0, p=t>>1; i<N; i+=t) for(int j=i; j<i+p; ++j){
                x=f[j], y=f[j+p];
                f[j]=(0ll+Mod+x+y)%Mod, f[j+p]=(0ll+Mod+x-y)%Mod;
                if(opt==-1) f[j]=1ll*f[j]*inv2%Mod, f[j+p]=1ll*f[j+p]*inv2%Mod;
            }
    }
};

using namespace __poly;

signed main(){
    
    return 0;
}
posted @ 2020-06-08 20:41  Arextre  阅读(203)  评论(0编辑  收藏  举报