洛谷P5050 【模板】多项式多点求值

https://www.luogu.org/problemnew/show/P5050

给定多项式A(x),求$A(x_l)$,$A(x_{l+1})$,..,$A(x_r)$

分治:(如果r-l+1=1,直接O(deg(A))暴力求出即可)

首先设$mid=\lfloor\frac{l+r}{2}\rfloor$,$P^{[0]}(x)=\prod_{i=l}^{mid}(x-x_i)$,$P^{[1]}(x)=\prod_{i=mid+1}^{r}(x-x_i)$

以[l,mid]的求值为例:设$A^{[0]}(x)=A(x)\,mod\,P^{[0]}(x)$

即$A(x)=P^{[0]}(x)B^{[0]}(x)+A^{[0]}(x)$($B^{[0]}$为某个多项式)

可以发现,将$x_l$,$x_{l+1}$,..,$x_{mid}$带入$P^{[0]}(x)$,值都为0

因此对于$l<=i<=mid$,$A(x_i)=A^{[0]}(x_i)$,递归下去算就行;[mid+1,r]的求值同理

这个P可以在分治过程中处理出来

时间复杂度大概是$O(n\,log^2\,n)$(未区分n=r-l+1,m=deg(A))

版本1:基于版本1,加了小范围暴力,预处理了P方便快速插值

注意:这种分治FFT的题,NTT里面wn需要预处理否则可能慢很多(听说多一个log,没有仔细分析)!

  1 #prag\
  2 ma GCC optimize(2)
  3 #include<cstdio>
  4 #include<algorithm>
  5 #include<cstring>
  6 #include<vector>
  7 #include<cmath>
  8 using namespace std;
  9 #define fi first
 10 #define se second
 11 #define mp make_pair
 12 #define pb push_back
 13 typedef long long ll;
 14 typedef unsigned long long ull;
 15 const int md=998244353;
 16 const int N=131072;
 17 #define delto(a,b) ((a)-=(b),((a)<0)&&((a)+=md))
 18 inline int del(int a,int b)
 19 {
 20     a-=b;
 21     return a<0?a+md:a;
 22 }
 23 int rev[N];
 24 void init(int len)
 25 {
 26     int bit=0,i;
 27     while((1<<(bit+1))<=len)    ++bit;
 28     for(i=1;i<len;++i)
 29         rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
 30 }
 31 ull poww(ull a,ull b)
 32 {
 33     ull ans=1;
 34     for(;b;b>>=1,a=a*a%md)
 35         if(b&1)
 36             ans=ans*a%md;
 37     return ans;
 38 }
 39 int inv[300011],pw_1[300011],pw_2[300011];
 40 void dft(int *a,int len,int idx)//要求len为2的幂
 41 {
 42     int i,j,k,t1,t2;ull wn,wnk;
 43     for(i=0;i<len;++i)
 44         if(i<rev[i])
 45             swap(a[i],a[rev[i]]);
 46     for(i=1;i<len;i<<=1)
 47     {
 48         wn=idx==1?pw_1[i]:pw_2[i];
 49         //wn=poww(idx==1?3:332748118,(md-1)/(i<<1));
 50         for(j=0;j<len;j+=(i<<1))
 51         {
 52             wnk=1;
 53             for(k=j;k<j+i;++k,wnk=wnk*wn%md)
 54             {
 55                 t1=a[k];t2=a[k+i]*wnk%md;
 56                 a[k]+=t2;
 57                 (a[k]>=md)&&(a[k]-=md);
 58                 a[k+i]=t1-t2;
 59                 (a[k+i]<0)&&(a[k+i]+=md);
 60             }
 61         }
 62     }
 63     if(idx==-1)
 64     {
 65         ull ilen=inv[len];
 66         for(i=0;i<len;++i)
 67             a[i]=a[i]*ilen%md;
 68     }
 69 }
 70 void p_inv(int *f,int *g,int len)//g=f^(-1);f,g数组的长度不小于2len(需要足够长用于临时存放元素);要求len是2的幂
 71 {
 72     static int t1[N],t2[N];
 73     g[0]=poww(f[0],md-2);
 74     for(int i=2,j;i<=len;i<<=1)
 75     {
 76         memcpy(t1,f,sizeof(int)*i);
 77         memcpy(t2,g,sizeof(int)*(i>>1));
 78         memset(t2+(i>>1),0,sizeof(int)*(i>>1));
 79         init(i);
 80         dft(t1,i,1);dft(t2,i,1);
 81         for(j=0;j<i;++j)
 82             t1[j]=ull(t1[j])*t2[j]%md;
 83         dft(t1,i,-1);
 84         for(j=0;j<(i>>1);++j)
 85             t1[j]=t1[j+(i>>1)];
 86         memset(t1+(i>>1),0,sizeof(int)*(i>>1));
 87         dft(t1,i,1);
 88         for(j=0;j<i;++j)
 89             t1[j]=ull(t1[j])*t2[j]%md;
 90         dft(t1,i,-1);
 91         for(j=i>>1;j<i;++j)
 92             g[j]=md-t1[j-(i>>1)];
 93     }
 94 }
 95 inline void p_de(int *f,int len)//derivative求导;f=f'
 96 {
 97     for(int i=0;i<len-1;++i)
 98         f[i]=ull(i+1)*f[i+1]%md;
 99     f[len-1]=0;
100 }
101 inline void p_in(int *f,int len)//integral积分;f=?f
102 {
103     for(int i=len-1;i>=1;--i)
104         f[i]=ull(f[i-1])*inv[i]%md;
105     f[0]=0;
106 }
107 void p_ln(int *f,int len)//要求len为2的幂,f[0]=1
108 {
109     static int t3[N];
110     p_inv(f,t3,len);p_de(f,len);
111     init(len<<1);
112     dft(f,len<<1,1);dft(t3,len<<1,1);
113     for(int i=0;i<(len<<1);++i)
114         f[i]=ull(f[i])*t3[i]%md;
115     dft(f,len<<1,-1);p_in(f,len);
116 }
117 void p_exp(int *f,int *g,int len)//要求len为2的幂,f[0]=0
118 {
119     static int t1[N],t2[N];
120     g[0]=1;
121     for(int i=2,j;i<=len;i<<=1)
122     {
123         memcpy(t1,g,sizeof(int)*(i>>1));
124         memset(t1+(i>>1),0,sizeof(int)*(i>>1));
125         p_ln(t1,i);
126         for(j=0;j<(i>>1);++j)
127             t1[j]=del(f[j+(i>>1)],t1[j+(i>>1)]);
128         memset(t1+(i>>1),0,sizeof(int)*(i>>1));
129         init(i);
130         dft(t1,i,1);
131         memcpy(t2,g,sizeof(int)*(i>>1));
132         memset(t2+(i>>1),0,sizeof(int)*(i>>1));
133         dft(t2,i,1);
134         for(j=0;j<i;++j)
135             t1[j]=ull(t1[j])*t2[j]%md;
136         dft(t1,i,-1);
137         for(j=i>>1;j<i;++j)
138             g[j]=t1[j-(i>>1)];
139     }
140 }
141 void p_div(int *a,int *b,int *c,int n,int m)//c=a/b;deg(a)=n,deg(b)=m,deg(c)=n-m;a,b无前导0;n>=m
142 {
143     reverse(a,a+n+1);reverse(b,b+m+1);
144     int x=n-m+1,t=1;
145     for(;t<x;t<<=1);
146     memset(b+m+1,0,sizeof(int)*max(t-m-1,0));
147     p_inv(b,c,t);
148     memset(c+x,0,sizeof(int)*((t<<1)-x));
149     memset(a+x,0,sizeof(int)*((t<<1)-x));
150     init(t<<1);
151     dft(a,t<<1,1);dft(c,t<<1,1);
152     for(int i=0;i<(t<<1);++i)
153         c[i]=ull(c[i])*a[i]%md;
154     dft(c,t<<1,-1);
155     memset(c+(n-m+1),0,sizeof(int)*((t<<1)-n+m-1));
156     reverse(c,c+x);
157 }
158 void p_divmod(int *a,int *b,int *c,int *d,int n,int m)//c=a/b,d=a%b,deg(d)=(<=)m-1;其余同上
159 {
160     static int t1[N];
161     memcpy(d,a,sizeof(int)*(m+1));
162     int x=n+1,t=1;
163     for(;t<x;t<<=1);
164     memcpy(t1,b,sizeof(int)*(m+1));
165     memset(t1+m+1,0,sizeof(int)*max(t-m-1,0));
166     p_div(a,b,c,n,m);
167     memcpy(a,c,sizeof(int)*(n-m+1));
168     memset(a+n-m+1,0,sizeof(int)*(t-n+m-1));
169     init(t);
170     dft(a,t,1);dft(t1,t,1);
171     for(int i=0;i<t;++i)
172         t1[i]=ull(t1[i])*a[i]%md;
173     dft(t1,t,-1);
174     for(int i=0;i<=m;++i)
175         delto(d[i],t1[i]);
176 }
177 namespace P_me
178 {
179     int *ta[N];//用线段树的方法给递归的每一层一个编号,ta[i]表示编号为i的层的P函数的各项系数
180     int data[N*40],*tp;//内存池
181     int *a,*x,*y;
182 #define LC (u<<1)
183 #define RC (u<<1|1)
184     int mt1[N];
185     const int T=200;//小范围暴力阀值
186     void _p_me1(int l,int r,int u)//计算(x-x_l)(x-x_{l+1})..(x-x_r)并存下来
187     {
188         if(r-l<=T)
189         {
190             int i,j;
191             tp[0]=1;
192             for(i=l;i<=r;++i)
193             {
194                 tp[i-l+1]=tp[i-l];
195                 for(j=i-l;j>=1;--j)
196                 {
197                     tp[j]=(ull(tp[j])*(md-x[i])+tp[j-1])%md;
198                 }
199                 tp[0]=ull(tp[0])*(md-x[i])%md;
200             }
201             ta[u]=tp;tp+=r-l+2;
202             return;
203         }
204         int mid=(l+r)>>1;
205         _p_me1(l,mid,LC);_p_me1(mid+1,r,RC);
206         int x=r-l+2,t=1;//x=(mid-l+1)+(r-mid)+1
207         for(;t<x;t<<=1);
208         memcpy(mt1,ta[LC],sizeof(int)*(mid-l+2));
209         memset(mt1+mid-l+2,0,sizeof(int)*(t-mid+l-2));
210         memcpy(tp,ta[RC],sizeof(int)*(r-mid+1));
211         memset(tp+r-mid+1,0,sizeof(int)*(t-r+mid-1));
212         init(t);
213         dft(mt1,t,1);dft(tp,t,1);
214         for(int i=0;i<t;++i)
215             tp[i]=ull(tp[i])*mt1[i]%md;
216         dft(tp,t,-1);
217         ta[u]=tp;tp+=r-l+2;
218     }
219     int mt2[N],mt3[N];
220     void _p_me2(int *a,int n,int l,int r,int u)//a是A的系数,deg(A)<=n;求A(x_l)到A(x_r),放入y_l到y_r
221     {
222         if(r-l<=T)
223         {
224             int t,i,j;
225             for(i=l;i<=r;++i)
226             {
227                 t=a[n];
228                 for(j=n-1;j>=0;--j)
229                     t=(ull(t)*x[i]+a[j])%md;
230                 y[i]=t;
231             }
232             return;
233         }
234         int x=(n+1)<<1,t=1;
235         for(;t<x;t<<=1);
236         int mt4[t];//根据需要改成new?
237         int mid=(l+r)>>1,n1;
238         memcpy(mt1,a,sizeof(int)*(n+1));
239         for(n1=n;n1>=0 && mt1[n1]==0;)    --n1;
240         if(n1<0)
241         {
242             memset(y+l,0,sizeof(int)*(r-l+1));
243             return;
244         }
245         memcpy(mt2,ta[LC],sizeof(int)*(mid-l+2));
246         if(n1<mid-l+1)
247         {
248             memcpy(mt4,mt1,sizeof(int)*(n1+1));
249             _p_me2(mt4,n1,l,mid,LC);
250         }
251         else
252         {
253             p_divmod(mt1,mt2,mt3,mt4,n1,mid-l+1);
254             _p_me2(mt4,mid-l,l,mid,LC);    
255         }
256         memcpy(mt1,a,sizeof(int)*(n+1));
257         for(n1=n;n1>=0 && mt1[n1]==0;)    --n1;
258         memcpy(mt2,ta[RC],sizeof(int)*(r-mid+1));
259         if(n1<r-mid)
260         {
261             memcpy(mt4,mt1,sizeof(int)*(n1+1));
262             _p_me2(mt4,n1,mid+1,r,RC);
263         }
264         else
265         {
266             p_divmod(mt1,mt2,mt3,mt4,n1,r-mid);
267             _p_me2(mt4,r-mid-1,mid+1,r,RC);
268         }
269     }
270     void p_multieval(int *a0,int *x0,int *y0,int n,int m)//deg(a)=n,x有m个数
271     {
272         tp=data;
273         a=a0;x=x0;y=y0;
274         _p_me1(0,m-1,1);
275         _p_me2(a,n,0,m-1,1);
276     }
277 }
278 using P_me::p_multieval;
279 int a[N],x[N],y[N];
280 int n,m;
281 int main()
282 {
283     int i;
284     inv[1]=1;
285     for(i=2;i<=300000;++i)
286         inv[i]=ull(md-md/i)*inv[md%i]%md;
287     for(i=1;i<300000;i<<=1)
288     {
289         pw_1[i]=poww(3,(md-1)/(i<<1));
290         pw_2[i]=poww(332748118,(md-1)/(i<<1));
291     }
292     //n=100000;m=100000;
293     scanf("%d%d",&n,&m);
294     for(i=0;i<=n;++i)
295         //a[i]=rand()%md;
296         scanf("%d",a+i);
297     for(i=0;i<m;++i)
298         //x[i]=rand()%md;
299         scanf("%d",x+i);
300     p_multieval(a,x,y,n,m);
301     for(i=0;i<m;++i)
302         printf("%d\n",y[i]);
303     return 0;
304 }
View Code

 

posted @ 2019-03-29 11:33  hehe_54321  阅读(340)  评论(0编辑  收藏  举报
AmazingCounters.com