【系列】【模板】 多项式

【模板】 多项式乘法

学习一波NTT 放弃FFT

 1 #include<iostream>
 2 #include<cstdio>
 3 #include<cstdlib>
 4 #include<cmath>
 5 #include<algorithm>
 6 #include<cstring>
 7 #include<vector>
 8 #include<queue>
 9 #include<map>
// Helper macros for this listing. `register` is not used: it was deprecated in
// C++11 and is ill-formed since C++17 (clang rejects it by default).
#define rep(i,s,t) for(int i=(s);i<=(t);++i)   // ascending, inclusive
#define dwn(i,s,t) for(int i=(s);i>=(t);--i)   // descending, inclusive
#define ren for(int i=fst[x];i;i=nxt[i])       // adjacency-list walk (not used in this listing)
#define Fill(x,t) memset(x,t,sizeof(x))        // byte-fill a whole array
#define ll long long
#define inf 2139062143                         // 0x7f7f7f7f
#define MOD 998244353                          // 119*2^23+1, NTT-friendly prime
#define MAXN 2001000
using namespace std;
// Reads one decimal integer from stdin and returns it.
// Characters before the first digit are skipped; if any of them is '-', the
// result is negated. Stops at EOF (returning the digits read so far, or 0)
// instead of spinning forever on truncated input.
inline int read()
{
    int x=0,f=1;
    int ch=getchar();   // int, not char, so EOF stays distinguishable
    while(ch!=EOF&&(ch<'0'||ch>'9')) {if(ch=='-') f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9') {x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
26 int n,m,lg,lmt;
27 ll A[MAXN<<2],B[MAXN<<2],rev[MAXN<<2];
28 ll q_pow(ll bas,ll t,ll res=1)
29 {
30     for(;t;t>>=1,(bas*=bas)%=MOD)
31         if(t&1) (res*=bas)%=MOD;return res;
32 }
33 void ntt(ll *a,int n,int f)
34 {
35     rep(i,0,n-1) if(i<rev[i])swap(a[i],a[rev[i]]);
36     for(int i=1;i<n;i<<=1)
37     {
38         ll wn=q_pow(3,(MOD-1)/(i<<1))%MOD;
39         if(f==-1)wn=q_pow(wn,MOD-2);
40         for(int j=0;j<n;j+=i<<1)
41         {
42             ll w=1,x,y;
43             for(int k=0;k<i;k++,w=wn*w%MOD)
44                 x=a[k+j],y=((ll)a[k+j+i]*w)%MOD,a[j+k]=(x+y)%MOD,a[j+k+i]=(x-y+MOD)%MOD;
45         }
46     }
47     if(f==1) return ;int nv=q_pow(n,MOD-2);
48     for(int i=0;i<n;i++) a[i]=a[i]*nv%MOD;
49 }
50 int main()
51 {
52     n=read()+1,m=read()+1;rep(i,0,n-1) A[i]=read();rep(i,0,m-1) B[i]=read();
53     lg=ceil(log2(n+m)),lmt=1<<lg;
54     rep(i,0,lmt-1) rev[i]=(rev[i>>1]>>1)|((i&1)<<(lg-1));
55     ntt(A,lmt,1);ntt(B,lmt,1);rep(i,0,lmt-1) (A[i]*=B[i])%=MOD;
56     ntt(A,lmt,-1);rep(i,0,n+m-2) printf("%lld ",A[i]);
57 }
View Code

 

(FFT板子)

 1 #include<iostream>
 2 #include<cstdio>
 3 #include<cstdlib>
 4 #include<cmath>
 5 #include<algorithm>
 6 #include<cstring>
 7 #include<vector>
 8 #include<queue>
 9 #include<complex>
10 #include<map>
// Helper macros for this listing. `register` is not used: it was deprecated in
// C++11 and is ill-formed since C++17 (clang rejects it by default).
#define rep(i,s,t) for(int i=(s);i<=(t);++i)   // ascending, inclusive
#define dwn(i,s,t) for(int i=(s);i>=(t);--i)   // descending, inclusive
#define ren for(int i=fst[x];i;i=nxt[i])       // adjacency-list walk (not used in this listing)
#define Fill(x,t) memset(x,t,sizeof(x))        // byte-fill a whole array
#define ll long long
#define Cd complex<double>
#define inf 2139062143                         // 0x7f7f7f7f
#define MOD 998244353
#define MAXN 1001000
using namespace std;
// Reads one decimal integer from stdin and returns it.
// Characters before the first digit are skipped; if any of them is '-', the
// result is negated. Stops at EOF (returning the digits read so far, or 0)
// instead of spinning forever on truncated input.
inline int read()
{
    int x=0,f=1;
    int ch=getchar();   // int, not char, so EOF stays distinguishable
    while(ch!=EOF&&(ch<'0'||ch>'9')) {if(ch=='-') f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9') {x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
28 int n,m;
29 const double pi=acos(-1);
30 int rev[MAXN<<2];
31 Cd A[MAXN<<2],B[MAXN<<2];
32 void fft(Cd *a,int n,int f)
33 {
34     rep(i,0,n-1) if(i<rev[i])swap(a[i],a[rev[i]]);
35     for(int i=1;i<n;i<<=1)
36     {
37         Cd wn(cos(pi/i),f*sin(pi/i));
38         for(int j=0;j<n;j+=i<<1)
39         {
40             Cd w(1,0),x,y;
41             for(int k=0;k<i;k++,w*=wn)
42                 x=a[k+j],y=a[k+j+i]*w,a[j+k]=x+y,a[j+k+i]=x-y;
43         }
44     }
45     if(f==1) return ;rep(i,0,n-1) a[i]/=n;
46 }
47 void solve(Cd *a,Cd *b,int sum)
48 {
49     int lg=ceil(log2(sum)),lmt=1<<lg;
50     rep(i,0,lmt-1) rev[i]=(rev[i>>1]>>1)|((i&1)<<(lg-1));
51     fft(a,lmt,1);fft(b,lmt,1);rep(i,0,lmt-1) a[i]*=b[i];
52     fft(a,lmt,-1);
53 }
54 int main()
55 {
56     n=read()+1,m=read()+1;rep(i,0,n-1) A[i]=read();rep(i,0,m-1) B[i]=read();
57     solve(A,B,n+m);rep(i,0,n+m-2) printf("%lld ",(ll)(A[i].real()+0.5));
58 }
View Code

 

【模板】多项式求逆

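背诵关键代码。思路是倍增求逆: 若 $B_0$ 满足 $A B_0 \equiv 1 \pmod{x^{t}}$, 则 $B = B_0(2 - A B_0) \bmod x^{2t}$ 满足 $A B \equiv 1 \pmod{x^{2t}}$。从 $B = A_0^{-1}$ 出发, 每轮把长度翻倍, 直到不小于 $n$。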

 1 #include<iostream>
 2 #include<cstdio>
 3 #include<cstdlib>
 4 #include<cmath>
 5 #include<algorithm>
 6 #include<cstring>
 7 #include<vector>
 8 #include<queue>
 9 #include<map>
// Helper macros for this listing. `register` is not used: it was deprecated in
// C++11 and is ill-formed since C++17 (clang rejects it by default).
#define rep(i,s,t) for(int i=(s);i<=(t);++i)   // ascending, inclusive
#define dwn(i,s,t) for(int i=(s);i>=(t);--i)   // descending, inclusive
#define ren for(int i=fst[x];i;i=nxt[i])       // adjacency-list walk (not used in this listing)
#define Fill(x,t) memset(x,t,sizeof(x))        // byte-fill a whole array
#define ll long long
#define inf 2139062143                         // 0x7f7f7f7f
#define MOD 998244353                          // 119*2^23+1, NTT-friendly prime
#define MAXN 2001000
using namespace std;
// Reads one decimal integer from stdin and returns it.
// Characters before the first digit are skipped; if any of them is '-', the
// result is negated. Stops at EOF (returning the digits read so far, or 0)
// instead of spinning forever on truncated input.
inline int read()
{
    int x=0,f=1;
    int ch=getchar();   // int, not char, so EOF stays distinguishable
    while(ch!=EOF&&(ch<'0'||ch>'9')) {if(ch=='-') f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9') {x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
26 int n,m;
27 ll A[MAXN<<2],B[MAXN<<2],F[MAXN<<2],rev[MAXN<<2];
28 ll q_pow(ll bas,ll t,ll res=1)
29 {
30     for(;t;t>>=1,(bas*=bas)%=MOD)
31         if(t&1) (res*=bas)%=MOD;return res;
32 }
33 void ntt(ll *a,int n,int f)
34 {
35     rep(i,0,n-1) if(i<rev[i])swap(a[i],a[rev[i]]);
36     for(int i=1;i<n;i<<=1)
37     {
38         ll wn=q_pow(3,(MOD-1)/(i<<1))%MOD;
39         if(f==-1)wn=q_pow(wn,MOD-2);
40         for(int j=0;j<n;j+=i<<1)
41         {
42             ll w=1,x,y;
43             for(int k=0;k<i;k++,w=wn*w%MOD)
44                 x=a[k+j],y=((ll)a[k+j+i]*w)%MOD,a[j+k]=(x+y)%MOD,a[j+k+i]=(x-y+MOD)%MOD;
45         }
46     }
47     if(f==1) return ;int nv=q_pow(n,MOD-2);
48     for(int i=0;i<n;i++) a[i]=a[i]*nv%MOD;
49 }
50 void solve(ll *a,ll *b,int sum)
51 {
52     int lg=ceil(log2(sum)),lmt=1<<lg;
53     rep(i,0,lmt-1) rev[i]=(rev[i>>1]>>1)|((i&1)<<(lg-1));
54     ntt(a,lmt,1);ntt(b,lmt,1);rep(i,0,lmt-1) a[i]=(((2LL-a[i]*b[i])%MOD+MOD)*a[i])%MOD;
55     ntt(a,lmt,-1);
56 }
57 int main()
58 {
59     n=read();rep(i,0,n-1) A[i]=read();
60     int lg=ceil(log2(n)),lmt=1<<lg;B[0]=q_pow(A[0],MOD-2);
61     for(int t=2;t<=lmt;t<<=1)
62     {
63         rep(i,0,t-1) F[i]=A[i];solve(B,F,t<<1);
64         rep(i,t,(t<<1)-1) B[i]=0;
65     }
66     rep(i,0,n-1) printf("%lld ",B[i]);
67 }
View Code

 

【模板】分治FFT

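类似于 CDQ 分治, 要求的是 $f_i=\sum_{j=1}^{i} f_{i-j}\,g_j$ ($f_0=1$)。对区间 $[l,r]$: 先递归算完左半 $[l,mid]$, 再用一次 NTT 把左半对右半的贡献 $\sum_{j=l}^{mid} f_j\,g_{i-j}$ 加到 $f_i$ ($mid< i\le r$) 上, 最后递归右半。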

 1 #include<iostream>
 2 #include<cstdio>
 3 #include<cstring>
 4 #include<cstdlib>
 5 #include<cmath>
 6 #include<algorithm>
 7 #include<queue>
 8 #include<vector>
 9 #include<map>
10 #include<set>
#define ll long long
#define inf 2139062143                 // 0x7f7f7f7f
#define MAXN 400100
#define MOD 998244353                  // 119*2^23+1; roots are powers of 3
// Inclusive loops whose bound is evaluated once. `register` is dropped (ill-formed
// since C++17), and the helper name avoids a reserved double-underscore identifier.
#define rep(i,s,t) for(int i=(s),i##_end_=(t);i<=i##_end_;++i)
#define dwn(i,s,t) for(int i=(s),i##_end_=(t);i>=i##_end_;--i)
#define ren(x) for(int i=fst[x];i;i=nxt[i])
#define pb(i,x) vec[i].push_back(x)
// Modular helpers, fully parenthesized so they compose safely inside larger
// expressions such as 2*mul(a,b) or mns(a,b+c). Operands are assumed in [0, MOD).
#define pls(a,b) (((a)+(b))%MOD)
#define mns(a,b) (((a)-(b)+MOD)%MOD)
#define mul(a,b) ((1LL*(a)*(b))%MOD)
using namespace std;
// Reads one decimal integer from stdin and returns it.
// Characters before the first digit are skipped; if any of them is '-', the
// result is negated. Stops at EOF (returning the digits read so far, or 0)
// instead of spinning forever on truncated input.
inline int read()
{
    int x=0,f=1;
    int ch=getchar();   // int, not char, so EOF stays distinguishable
    while(ch!=EOF&&(ch<'0'||ch>'9')) {if(ch=='-') f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9') {x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
30 int n,g[MAXN],f[MAXN],rev[MAXN],pw[MAXN],l2[MAXN];
31 int A[MAXN],B[MAXN];
32 int q_pow(int bas,int t,int res=1)
33 {
34     for(;t;bas=mul(bas,bas),t>>=1)
35         if(t&1) res=mul(res,bas);return res;
36 }
37 void ntt(int *a,int n,int f)
38 {
39     rep(i,0,n-1) if(i<rev[i]) swap(a[i],a[rev[i]]);
40     for(int i=1;i<n;i<<=1)
41     {
42         int wn=pw[i<<1];if(f==-1) wn=q_pow(wn,MOD-2);
43         for(int j=0;j<n;j+=i<<1)
44         {
45             int w=1,x,y;
46             for(int k=0;k<i;k++,w=mul(w,wn))
47                 x=a[k+j],y=mul(a[i+j+k],w),a[j+k]=pls(x,y),a[i+j+k]=mns(x,y);
48         }
49     }
50     if(f==1) return ;int nv=q_pow(n,MOD-2);
51     rep(i,0,n-1) a[i]=mul(a[i],nv);
52 }
53 void solve(int *a,int *b,int lmt)
54 {
55     ntt(a,lmt,1);ntt(b,lmt,1);rep(i,0,lmt-1) a[i]=mul(a[i],b[i]);
56     ntt(a,lmt,-1);
57 }
58 void cdq(int l,int r)
59 {
60     if(l==r) return ;int mid=l+r>>1,lmt=r-l+1;cdq(l,mid);
61     int t=l2[lmt]+1;lmt=1<<t;
62     rep(i,0,lmt-1) rev[i]=(rev[i>>1]>>1)|((i&1)<<(t-1)),A[i]=B[i]=0;
63     rep(i,l,mid) A[i-l]=f[i];rep(i,1,r-l) B[i-1]=g[i];
64     solve(A,B,lmt);rep(i,mid+1,r) f[i]=pls(f[i],A[i-l-1]);
65     cdq(mid+1,r);
66 }
67 int main()
68 {
69     n=read();rep(i,1,n-1) g[i]=read();rep(i,2,n<<1) l2[i]=l2[i>>1]+1;
70     rep(i,2,n<<1) pw[i]=q_pow(3,(MOD-1)/i);
71     f[0]=1;cdq(0,n-1);rep(i,0,n-1) printf("%d ",f[i]);
72 }
View Code

 

bzoj 3527 力

题目大意:

给出$q_i$已知$F_j= \sum_{i<j} \frac {q_i \times q_j}{(i-j)^2}-\sum_{i>j} \frac{q_i \times q_j}{(i-j)^2}$ 

令$E_i= \frac{F_i}{q_i}$ 求所有的$E_i$

思路:

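由于 $q_j$ 被除掉了, $E_j=\sum_{i<j}\frac{q_i}{(j-i)^2}-\sum_{i>j}\frac{q_i}{(i-j)^2}$。令 $a_k=\frac{1}{k^2}$ ($a_0=0$), 第一部分就是 $q$ 与 $a$ 的卷积在下标 $j$ 处的值; 把 $q$ 翻转后再与 $a$ 卷积, 第二部分就是结果在下标 $n-1-j$ 处的值。因此两部分分别用 FFT 卷积, 再相减即可。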

 1 #include<iostream>
 2 #include<cstdio>
 3 #include<cstdlib>
 4 #include<cmath>
 5 #include<algorithm>
 6 #include<cstring>
 7 #include<vector>
 8 #include<queue>
 9 #include<complex>
10 #include<map>
// Helper macros for this listing. `register` is not used: it was deprecated in
// C++11 and is ill-formed since C++17 (clang rejects it by default).
#define rep(i,s,t) for(int i=(s);i<=(t);++i)   // ascending, inclusive
#define dwn(i,s,t) for(int i=(s);i>=(t);--i)   // descending, inclusive
#define ren for(int i=fst[x];i;i=nxt[i])       // adjacency-list walk (not used in this listing)
#define Fill(x,t) memset(x,t,sizeof(x))        // byte-fill a whole array
#define ll long long
#define Cd complex<double>
#define inf 2139062143                         // 0x7f7f7f7f
#define MOD 998244353
#define MAXN 100100
using namespace std;
// Reads one decimal integer from stdin and returns it.
// Characters before the first digit are skipped; if any of them is '-', the
// result is negated. Stops at EOF (returning the digits read so far, or 0)
// instead of spinning forever on truncated input.
inline int read()
{
    int x=0,f=1;
    int ch=getchar();   // int, not char, so EOF stays distinguishable
    while(ch!=EOF&&(ch<'0'||ch>'9')) {if(ch=='-') f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9') {x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
28 int n,m;
29 const double pi=acos(-1);
30 int rev[MAXN<<2];
31 Cd A1[MAXN<<2],A2[MAXN<<2],B[MAXN<<2],C[MAXN<<2];
32 void fft(Cd *a,int n,int f)
33 {
34     rep(i,0,n-1) if(i<rev[i])swap(a[i],a[rev[i]]);
35     for(int i=1;i<n;i<<=1)
36     {
37         Cd wn(cos(pi/i),f*sin(pi/i));
38         for(int j=0;j<n;j+=i<<1)
39         {
40             Cd w(1,0),x,y;
41             for(int k=0;k<i;k++,w*=wn)
42                 x=a[k+j],y=a[k+j+i]*w,a[j+k]=x+y,a[j+k+i]=x-y;
43         }
44     }
45     if(f==1) return ;rep(i,0,n-1) a[i]/=n;
46 }
47 void solve(Cd *a,Cd *b,int sum)
48 {
49     int lg=ceil(log2(sum)),lmt=1<<lg;
50     rep(i,0,lmt-1) rev[i]=(rev[i>>1]>>1)|((i&1)<<(lg-1));
51     fft(a,lmt,1);fft(b,lmt,1);rep(i,0,lmt-1) a[i]*=b[i];
52     fft(a,lmt,-1);
53 }
54 int main()
55 {
56     n=read();rep(i,1,n) A1[i]=A2[i]=(double)(1.0/i/i);
57     rep(i,0,n-1) scanf("%lf",&B[i]),C[n-i-1]=B[i];
58     solve(B,A1,n<<1);solve(C,A2,n<<1);
59     rep(i,0,n-1) printf("%.3lf\n",B[i].real()-C[n-1-i].real());
60 }
View Code

 

bzoj 3513 idiots

题目大意:

给定n个长度分别为$a_i$的木棒,问随机选择3个木棒能够拼成三角形的概率

思路:

转化成有多少种方案不能拼成然后拿全集去减 用一个多项式表示每个值出现了多少次

然后自己卷自己, 卷出来的多项式 $F_i$ 表示长度之和等于 $i$ 的有序木棒对个数; 再对每根木棒 $a_k$ 把 $F_{2a_k}$ 减 1 (一根木棒和它自己配对也被卷进去了)。

最后每项 $/2$, 得到无序对的个数。对 $F$ 求前缀和 $S_i=\sum_{k\le i}F_k$, 长度为 $i$ 的木棒个数乘以 $S_i$, 就是以长度为 $i$ 的木棒为最长边、另外两根长度之和不超过 $i$ (拼不成三角形) 的方案数。

 1 #include<iostream>
 2 #include<cstdio>
 3 #include<cstdlib>
 4 #include<cmath>
 5 #include<algorithm>
 6 #include<cstring>
 7 #include<vector>
 8 #include<queue>
 9 #include<complex>
10 #include<map>
// Helper macros for this listing. `register` is not used: it was deprecated in
// C++11 and is ill-formed since C++17 (clang rejects it by default).
#define rep(i,s,t) for(int i=(s);i<=(t);++i)   // ascending, inclusive
#define dwn(i,s,t) for(int i=(s);i>=(t);--i)   // descending, inclusive
#define ren for(int i=fst[x];i;i=nxt[i])       // adjacency-list walk (not used in this listing)
#define Fill(x,t) memset(x,t,sizeof(x))        // byte-fill a whole array
#define ll long long
#define inf 2139062143                         // 0x7f7f7f7f
#define MOD 998244353
#define MAXN 100100
using namespace std;
// Reads one decimal integer from stdin and returns it.
// Characters before the first digit are skipped; if any of them is '-', the
// result is negated. Stops at EOF (returning the digits read so far, or 0)
// instead of spinning forever on truncated input.
inline int read()
{
    int x=0,f=1;
    int ch=getchar();   // int, not char, so EOF stays distinguishable
    while(ch!=EOF&&(ch<'0'||ch>'9')) {if(ch=='-') f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9') {x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
27 int n,m;
28 const double pi=acos(-1);
29 struct Cd{double x,y;Cd(double X=0,double Y=0){x=X,y=Y;};};
30 Cd operator + (Cd a,Cd b) {return (Cd){a.x+b.x,a.y+b.y};}
31 Cd operator - (Cd a,Cd b) {return (Cd){a.x-b.x,a.y-b.y};}
32 Cd operator * (Cd a,Cd b) {return (Cd){a.x*b.x-a.y*b.y,a.y*b.x+a.x*b.y};}
33 int rev[MAXN<<2];
34 Cd A[MAXN<<2],B[MAXN<<2];
35 ll f[MAXN<<2],g[MAXN<<2],ans;
36 void fft(Cd *a,int n,int f)
37 {
38     rep(i,0,n-1) if(i<rev[i])swap(a[i],a[rev[i]]);
39     for(int i=1;i<n;i<<=1)
40     {
41         Cd wn(cos(pi/i),f*sin(pi/i));
42         for(int j=0;j<n;j+=i<<1)
43         {
44             Cd w(1,0),x,y;
45             for(int k=0;k<i;k++,w=w*wn)
46                 x=a[k+j],y=a[k+j+i]*w,a[j+k]=x+y,a[j+k+i]=x-y;
47         }
48     }
49     if(f==1) return ;rep(i,0,n-1) a[i].x/=(double)n;
50 }
51 void solve(Cd *a,int sum)
52 {
53     int lg=ceil(log2(sum)),lmt=1<<lg;
54     rep(i,0,lmt-1) rev[i]=(rev[i>>1]>>1)|((i&1)<<(lg-1));
55     fft(a,lmt,1);rep(i,0,lmt-1) a[i]=a[i]*a[i];
56     fft(a,lmt,-1);rep(i,0,lmt-1) (g[i]+=(ll)(a[i].x+0.5))>>=1;
57 }
58 int main()
59 {
60     int T=read(),mx,x;double res; while(T--) 
61     {
62         Fill(A,0);Fill(f,0);Fill(g,0);mx=ans=0;
63         n=read();rep(i,0,n-1) A[x=read()].x++,g[x<<1]-=1,f[x]++,mx=max(mx,x<<1|1);
64         solve(A,mx);rep(i,0,mx) g[i]+=g[i-1],ans+=f[i]*g[i];
65         res=(double)(6.0*ans)/n/(n-1)/(n-2);printf("%.7lf\n",1.0-res);
66     }
67 }
View Code

 

posted @ 2018-12-11 15:45  jack_yyc  阅读(237)  评论(0编辑  收藏  举报