LOJ 3058 「HNOI2019」白兔之舞——单位根反演+MTT
题目:https://loj.ac/problem/3058
先考虑 n=1 怎么做。令 a 表示输入的 w[1][1] 。
\( ans_t = \sum\limits_{i=0}^{L}C_{L}^{i} a^i [ k|(i-t) ] \)
\(= \frac{1}{k}\sum\limits_{i=0}^{L}C_{L}^{i} a^i \sum\limits_{j=0}^{k-1} w_{k}^{j*(i-t)} \)
\(= \frac{1}{k}\sum\limits_{j=0}^{k-1}w_{k}^{-j*t} \sum\limits_{i=0}^{L}C_{L}^{i} a^i w_{k}^{i*j} \)
\(= \frac{1}{k}\sum\limits_{j=0}^{k-1}w_{k}^{-j*t} (1+a*w_{k}^{j})^{L} \)
这样是 k2 的,就不会了……
考虑卷积。把 -j*t 拆成只和 j 有关的与只和 t 或者 t+j 、t-j 有关的。
注意到 \( j*t = C_{j+t}^{2} - C_{t}^{2} - C_{j}^{2} \) 。考虑 j*t 表示从 j 个里选一个、再从 t 个里选一个;表示成从 (j+t) 里选两个,再减去不合法的,即从 j 个里选了两个或从 t 个里选了两个。
\( ans_t = \frac{1}{k}\sum\limits_{j=0}^{k}w_{k}^{-\binom{j+t}{2}+\binom{j}{2}+\binom{t}{2}} (1+a*w_{k}^{j})^{L} \)
最后那个部分只和 j 有关。所以令 \( c_j = (1+a*w_{k}^{j})^{L} \)
\(= \frac{ w_{k}^{\binom{t}{2}} }{k}\sum\limits_{j=0}^{k-1}w_{k}^{\binom{i}{2}} c_j * w_{k}^{-\binom{j+t}{2}} \)
然后可以卷积。
如果 n>1 ,用矩阵表示 “从 x 用 i 步走到 y ”的方案!仍然要乘组合数。
也就是除了 \( c_j = ( I + A*w_{k}^{j} )^{L} [x,y] \) 之外都没变。其中 A 是输入的矩阵,[x,y] 表示取矩阵的第 x 行第 y 列的值作为 \( c_j \) 。
给矩阵乘一个数字,是给其每个位置都乘。
不开 long double 会变成 0 分。
复习 MTT 。
#include<cstdio> #include<cstring> #include<algorithm> #include<cmath> #define db long double #define ll long long using namespace std; const int N=(1<<18)+5; const db pi2=acos(-1)*2; int n,k,L,x,y,mod,G,bs,len,r[N],c[N],f[N],g[N],wn[N]; int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;} int pw(int x,int k) {int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;} struct cpl{ db x,y; cpl(db x=0,db y=0):x(x),y(y) {} cpl operator+ (const cpl &b)const{return cpl(x+b.x,y+b.y);} cpl operator- (const cpl &b)const{return cpl(x-b.x,y-b.y);} cpl operator* (const cpl &b)const {return cpl(x*b.x-y*b.y,x*b.y+y*b.x);} cpl operator/ (const db &b)const{return cpl(x/b,y/b);} }Ta[N],Tb[N],P[N],Q[N]; cpl cnj(cpl a){return cpl(a.x,-a.y);} struct Mtr{ int a[3][3]; Mtr(){memset(a,0,sizeof a);} Mtr operator* (const Mtr &b)const { Mtr c; for(int i=0;i<n;i++) for(int k=0;k<n;k++) for(int j=0;j<n;j++) c.a[i][j]=(c.a[i][j]+(ll)a[i][k]*b.a[k][j])%mod; return c; } Mtr operator* (const int &b)const { Mtr c; for(int i=0;i<n;i++) for(int j=0;j<n;j++)c.a[i][j]=(ll)a[i][j]*b%mod; return c; } Mtr operator+ (const Mtr &b)const { Mtr c; for(int i=0;i<n;i++) for(int j=0;j<n;j++)c.a[i][j]=upt(a[i][j]+b.a[i][j]); return c; } }A,tA,tA2,I; namespace get_G{ int p[35],tot; void solve() { int k=mod-1; for(int i=2;i*i<=k;i++) if(k%i==0){ p[++tot]=(mod-1)/i; while(k%i==0)k/=i;} for(int i=2;;i++) { bool fg=0; for(int j=1;j<=tot;j++) if(pw(i,p[j])==1){fg=1;break;} if(!fg){G=i;break;} } } } void fft(cpl *a,bool fx) { for(int i=0;i<len;i++) if(i<r[i])swap(a[i],a[r[i]]); for(int R=2;R<=len;R<<=1) { cpl wn=cpl(cos(pi2/R),fx?-sin(pi2/R):sin(pi2/R)); for(int i=0,m=R>>1;i<len;i+=R) { cpl w=cpl(1,0); for(int j=0;j<m;j++,w=w*wn) { cpl x=a[i+j],y=w*a[i+m+j]; a[i+j]=x+y; a[i+m+j]=x-y; } } } if(!fx)return; for(int i=0;i<len;i++)a[i]=a[i]/len; } void MTT(int *a,int *b,int n,int m,int *c) { bs=sqrt(mod); cpl ta,tb,tc,td; for(len=1;len<=n+m;len<<=1); for(int i=0,j=len>>1;i<len;i++) r[i]=(r[i>>1]>>1)+((i&1)?j:0); for(int i=0;i<=n;i++) P[i]=cpl(a[i]/bs,a[i]%bs); for(int i=0;i<=m;i++) Q[i]=cpl(b[i]/bs,b[i]%bs); fft(P,0); fft(Q,0); P[len]=P[0]; Q[len]=Q[0]; for(int i=0;i<len;i++) { ta=(P[i]+cnj(P[len-i]))*cpl(0.5,0); tb=(P[i]-cnj(P[len-i]))*cpl(0,-0.5); tc=(Q[i]+cnj(Q[len-i]))*cpl(0.5,0); td=(Q[i]-cnj(Q[len-i]))*cpl(0,-0.5); Ta[i]=ta*tc+ta*td*cpl(0,1); Tb[i]=tb*tc+tb*td*cpl(0,1); } fft(Ta,1); fft(Tb,1); for(int i=0,lm=n+m,bs2=bs*bs;i<=lm;i++) { int a2=(ll)(Ta[i].x+0.5)%mod;//%mod int b2=(ll)(Ta[i].y+0.5)%mod; int c2=(ll)(Tb[i].x+0.5)%mod; int d2=(ll)(Tb[i].y+0.5)%mod; c[i]=((ll)a2*bs2+(ll)(b2+c2)*bs+d2)%mod; } } int main() { scanf("%d%d%d%d%d%d",&n,&k,&L,&x,&y,&mod); get_G::solve(); x--; y--; for(int i=0;i<n;i++) for(int j=0;j<n;j++)scanf("%d",&A.a[i][j]); wn[0]=1; wn[1]=pw(G,(mod-1)/k); for(int i=2;i<=k;i++)wn[i]=(ll)wn[i-1]*wn[1]%mod;//<=k for(int i=0;i<n;i++)I.a[i][i]=1;// for(int i=0;i<k;i++) { tA=A*wn[i]+I; tA2=I; int tp=L;//tp=L not i!!!!! while(tp) { if(tp&1)tA2=tA2*tA; tA=tA*tA; tp>>=1;} c[i]=tA2.a[x][y]; } for(int i=0;i<k;i++) f[k-1-i]=(ll)wn[(ll)i*(i-1)/2%k]*c[i]%mod; for(int i=0,lm=2*(k-1);i<=lm;i++) g[i]=wn[k-(ll)i*(i-1)/2%k]; MTT(f,g,k-1,2*(k-1),f); int inv=pw(k,mod-2); for(int i=0;i<k;i++) { int ans=(ll)wn[(ll)i*(i-1)/2%k]*inv%mod; ans=(ll)ans*f[k-1+i]%mod; printf("%d\n",ans); } return 0; }