bzoj 3992 [SDOI2015] 序列统计 —— NTT (循环卷积+快速幂)

题目:https://www.lydsy.com/JudgeOnline/problem.php?id=3992

(学习NTT:https://riteme.github.io/blog/2016-8-22/ntt.html

https://www.cnblogs.com/Mychael/p/9297652.html

http://blog.miskcoo.com/2015/04/polynomial-multiplication-and-fast-fourier-transform#i-15 )

首先,如果把方案数和乘积分别放在系数和次数上,就可以用多项式做了;

方案数放在系数上好说,但次数是相加的,如何表示乘积?

考虑乘积与加法的关系 —— 幂的相乘就是指数相加;

所以可以找出乘积的模数 m 的原根,用其次数相加代表乘积,这个次数好像被称为“指标”;

构造出多项式,由于要取模,所以用 NTT 做;

也就是要把初始的多项式做 n 次幂,可以用快速幂,但注意累乘起来的是系数而不是点值;

指标从0开始或从1开始都可以,也就是把 0 次方作为 1 和把 m-1 次方作为 1 的区别,对应系数的时候要根据这个注意一下(代码中注释里的方案也可);

初始化一个多项式并不是把每个系数都赋值1!而只有第0项是1,这样别的多项式乘过来还是那个多项式;

然后要特别注意读入时去掉0!因为原根系列中没有模出0的,所以以原根为基础的 NTT 算的时候不能考虑0,而反正最后要求的方案中,x >= 1,一旦有0,乘积就是0了,所以0对答案没有影响,就当没给这个数算即可;

一下午的心血...

代码如下:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
int const xn=(1<<14),xm=8005,mod=1004535809;
int n,m,rev[xn],g,a[xn],b[xn],lim,r[xm],cnt,pri[xm],inv;
int rd()
{
  int ret=0,f=1; char ch=getchar();
  while(ch<'0'||ch>'9'){if(ch=='-')f=0; ch=getchar();}
  while(ch>='0'&&ch<='9')ret=(ret<<3)+(ret<<1)+ch-'0',ch=getchar();
  return f?ret:-ret;
}
int pw(ll a,int b,int md)
{
  ll ret=1;
  for(;b;b>>=1,a=(a*a)%md)if(b&1)ret=(ret*a)%md;
  return ret;
}
void div(int x)
{
  for(int i=2;i*i<=x;i++)
    {
      if(x%i)continue;
      pri[++cnt]=i; while(x%i==0)x/=i;
    }
  if(x>1)pri[++cnt]=x;
}
void init()
{
  lim=1; int l=0;
  while(lim<=m+m)lim<<=1,l++;
  //while(lim<=2*(m-1))lim<<=1,l++;
  for(int i=0;i<lim;i++)
    rev[i]=((rev[i>>1]>>1)|((i&1)<<(l-1)));
  inv=pw(lim,mod-2,mod);

  if(m==2){g=1; return;}
  div(m-1);
  for(g=2;;g++)
    {
      bool f=0;
      for(int j=1;j<=cnt;j++)
    if(pw(g,(m-1)/pri[j],m)==1){f=1; break;}
      if(!f)break;
    }
  for(int i=1,k=g;i<m;i++,k=(ll)k*g%m)r[k]=i;
  //for(int i=0,k=1;i<m-1;i++,k=(ll)k*g%m)r[k]=i;//k=1
}
int upt(int x){while(x>=mod)x-=mod; while(x<0)x+=mod; return x;}
void ntt(int *a,int tp)
{
  for(int i=0;i<lim;i++)
    if(i<rev[i])swap(a[i],a[rev[i]]);
  for(int mid=1;mid<lim;mid<<=1)
    {
      int wn=pw(3,(mod-1)/(mid<<1),mod);
      if(tp==-1)wn=pw(wn,mod-2,mod);//
      for(int j=0,len=(mid<<1);j<lim;j+=len)
    {
      int w=1;
      for(int k=0;k<mid;k++,w=(ll)w*wn%mod)
        {
          int x=a[j+k],y=(ll)w*a[j+mid+k]%mod;
          a[j+k]=upt(x+y); a[j+mid+k]=upt(x-y);
        }
    }
    }
}
void pww()
{
  ntt(a,1);
  for(int i=0;i<lim;i++)a[i]=(ll)a[i]*a[i]%mod;
  ntt(a,-1); 
  for(int i=0;i<lim;i++)a[i]=(ll)a[i]*inv%mod;
  for(int i=m;i<lim;i++)a[i%m+1]=upt(a[i%m+1]+a[i]),a[i]=0;//%m+1
  //for(int i=m-1;i<lim;i++)a[i%(m-1)]=upt(a[i%(m-1)]+a[i]),a[i]=0;
}
void mul()
{
  ntt(a,1); ntt(b,1);//
  for(int i=0;i<lim;i++)b[i]=(ll)b[i]*a[i]%mod;
  ntt(a,-1); ntt(b,-1);//
  for(int i=0;i<lim;i++)
    a[i]=(ll)a[i]*inv%mod,b[i]=(ll)b[i]*inv%mod;
  
  for(int i=m;i<lim;i++)
    a[i%m+1]=upt(a[i%m+1]+a[i]),a[i]=0,
      b[i%m+1]=upt(b[i%m+1]+b[i]),b[i]=0;
  /*
  for(int i=m-1;i<lim;i++)
    a[i%(m-1)]=upt(a[i%(m-1)]+a[i]),a[i]=0,
      b[i%(m-1)]=upt(b[i%(m-1)]+b[i]),b[i]=0;
  */
}
int main()
{
  n=rd(); m=rd(); init();
  int p=rd(),num=rd();
  for(int i=1,x;i<=num;i++)
    {
      x=rd();
      if(x)a[r[x]]=1;//x!=0 !!
    }
  int t=n;
  //for(int i=0;i<lim;i++)b[i]=1;
  b[0]=1;//!
  for(;t;t>>=1,pww())if(t&1)mul();
  printf("%d\n",b[r[p]]);
  return 0;
}

 

posted @ 2018-11-28 18:23  Zinn  阅读(455)  评论(0编辑  收藏  举报