洛谷 P5293 [HNOI2019]白兔之舞
有一张顶点数为 \((L+1)\times n\) 的有向图。这张图的每个顶点由一个二元组\((u,v)\)表示\((0\le u\le L,1\le v\le n)\)。
这张图不是简单图,对于任意两个顶点 \((u_1,v_1)(u_2,v_2)\),如果 \(u_1<u_2\),则从 \((u_1,v_1)\) 到 \((u_2,v_2)\) 一共有 \(w[v_1][v_2]\) 条不同的边,如果 \(u_1\ge u_2\) 则没有边。
白兔将在这张图上上演一支舞曲。白兔初始时位于该有向图的顶点 \((0,x)\)。
白兔将会跳若干步。每一步,白兔会从当前顶点沿任意一条出边跳到下一个顶点。白兔可以在任意时候停止跳舞(也可以没有跳就直接结束)。当到达第一维为 \(L\) 的顶点就不得不停止,因为该顶点没有出边。
假设白兔停止时,跳了 \(m\) 步,白兔会把这只舞曲给记录下来成为一个序列。序列的第 \(i\) 个元素为它第 \(i\) 步经过的边。
问题来了:给定正整数 \(k\) 和 \(y\)(\(1\le y\le n\)),对于每个 \(t\)(\(0\le t<k\)),求有多少种舞曲(假设其长度为 \(m\))满足 \(m \bmod k=t\),且白兔最后停在了坐标第二维为 \(y\) 的顶点?
两支舞曲不同定义为它们的长度(\(m\))不同或者存在某一步它们所走的边不同。
输出的结果对 \(p\) 取模。
对于全部数据,\(p\) 为一个质数,\(10^8<p<2^{30}\),\(1\le n\le 3\),\(1\le x\le n\),\(1\le y\le n\),\(0\le w(i,j)<p\),\(1\le k\le 65536\),\(k\) 为 \(p-1\) 的约数,\(1\le L\le 10^8\)。
首先可以考虑dp,设 \(f_{i,j}\) 表示走了 \(i\) 步,最后第二维停在了 \(j\) 上的方案数。
这个东西不好dp,再设一个 \(g_{i,j}\) 表示挨着走了 \(i\) 个格子,最后第二维走到了 \(j\) 上的方案数,这个就可以dp了,有:
这个是非常可以矩阵优化的,那么就可以写成:
然后 \(f_{i,j}\) 可以看作从 \(L\) 个格子中选了了 \(i\) 个落脚点,就有:
最后我们的答案也就变成了:
这东西熟啊,直接单位根反演,就有:
把 \(f_{m,y}\) 替换成 \(G_m\) 的系数的形式
其中 \(I\) 是单位矩阵,设 \(F_d=[y]G_0(\omega_k^dW+I)^L\) ,那么这个东西可以通过矩阵快速幂预处理出来,然后就有:
这个是个循环卷积,我们用 \({i+j\choose 2}-{i\choose2}-{j\choose 2}\) 来替换 \(\omega_k\) 的系数就有:
发现这是个差卷积,然后模数不固定,还要做mtt。
Code
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
const int N = 65536;
const int M = 1e6;
const long double Pi = acos(-1.0);
using namespace std;
int n,k,L,x,y,p,g,rev[M + 5],maxn,lg,prime[N + 5],pcnt,ow[N + 5],A[M + 5],B[M + 5],C[M + 5];
struct Matrix
{
int a[4][4];
}G,W,nw,I;
Matrix operator *(Matrix a,Matrix b)
{
Matrix c;
for (int i = 1;i <= n;i++)
for (int j = 1;j <= n;j++)
c.a[i][j] = 0;
for (int i = 1;i <= n;i++)
for (int j = 1;j <= n;j++)
for (int k = 1;k <= n;k++)
c.a[i][j] += 1ll * a.a[i][k] * b.a[k][j] % p,c.a[i][j] %= p;
return c;
}
struct node
{
double x,y;
node conj(){return (node){x,-y};}
}w[M + 5],c[M + 5],d[M + 5],x1[M + 5],x2[M + 5],x3[M + 5],aa[M + 5],bb[M + 5],cc[M + 5],dd[M + 5],conj;
node operator +(node a,node b){return (node){a.x + b.x,a.y + b.y};}
node operator -(node a,node b){return (node){a.x - b.x,a.y - b.y};}
node operator *(node a,node b){return (node){a.x * b.x - a.y * b.y,a.x * b.y + a.y * b.x};}
int mypow(int a,int x,int p){int s = 1;for (;x;x & 1 ? s = 1ll * s * a % p : 0,a = 1ll * a * a % p,x >>= 1);return s;}
void prework(int n)
{
maxn = 1;lg = 0;
while (maxn <= n)
maxn <<= 1,lg++;
for (int i = 0;i < maxn;i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << lg - 1);
}
int getphi(int n)
{
int ans = n;
for (int i = 2;i * i <= n;i++)
if (n % i == 0)
{
while (n % i == 0)
{
n /= i;
ans = ans / i * (i - 1);
}
}
if (n != 1)
ans = ans / n * (n - 1);
return ans;
}
void getprime(int n)
{
for (int i = 2;i * i <= n;i++)
if (n % i == 0)
{
prime[++pcnt] = i;
while (n % i == 0)
n /= i;
}
if (n != 1)
prime[++pcnt] = n;
}
int get(int n)
{
int phi = getphi(n);
getprime(phi);
for (int i = 1;i < n;i++)
if (mypow(i,phi,n) == 1)
{
int fl = 1;
for (int j = 1;j <= pcnt;j++)
if (mypow(i,phi / prime[j],n) == 1)
{
fl = 0;
break;
}
if (fl)
return i;
}
}
void fft(node *a,int typ)
{
for (int i = 0;i < maxn;i++)
if (i < rev[i])
swap(a[i],a[rev[i]]);
for (int i = 1;i < maxn;i <<= 1)
for (int j = 0;j < maxn;j += i << 1)
for (int k = 0;k < i;k++)
{
node x = a[j + k],t = (node){w[k + i].x,w[k + i].y * typ} * a[j + k + i];
a[j + k] = x + t;
a[j + k + i] = x - t;
}
if (typ == -1)
for (int i = 0;i < maxn;i++)
a[i].x /= maxn,a[i].y /= maxn;
}
Matrix mpow(Matrix a,int x)
{
Matrix s = I;
while (x)
{
if (x & 1)
s = s * a;
a = a * a;
x >>= 1;
}
return s;
}
void getmatrix(int d)
{
memset(G.a,0,sizeof(G.a));
G.a[1][x] = 1;
memset(nw.a,0,sizeof(nw.a));
for (int i = 1;i <= n;i++)
for (int j = 1;j <= n;j++)
nw.a[i][j] = 1ll * ow[d] * W.a[i][j] % p + I.a[i][j];
G = G * mpow(nw,L);
}
int C2(int n){return 1ll * n * (n - 1) / 2 % k;}
void mtt(int *a,int *b,int n,int p)
{
for (int i = 0;i < n;i++)
a[i] %= p,b[i] %= p;
int bs = 32768;
for (int i = 0;i < n;i++)
c[i] = (node){a[i] / bs,a[i] % bs};
fft(c,1);
d[0] = c[0].conj();
for (int i = 1;i < maxn;i++)
d[i] = c[maxn - i].conj();
for (int i = 0;i < maxn;i++)
aa[i] = (c[i] + d[i]) * (node){0.5,0},bb[i] = (c[i] - d[i]) * (node){0,-0.5};
for (int i = 0;i < maxn;i++)
c[i] = d[i] = (node){0,0};
for (int i = 0;i < n;i++)
c[i] = (node){b[i] / bs,b[i] % bs};
fft(c,1);
d[0] = c[0].conj();
for (int i = 1;i < maxn;i++)
d[i] = c[maxn - i].conj();
for (int i = 0;i < maxn;i++)
cc[i] = (c[i] + d[i]) * (node){0.5,0},dd[i] = (c[i] - d[i]) * (node){0,-0.5};
for (int i = 0;i < maxn;i++)
x1[i] = aa[i] * cc[i],x2[i] = aa[i] * dd[i] + cc[i] * bb[i],x3[i] = bb[i] * dd[i];
for (int i = 0;i < maxn;i++)
x1[i] = x1[i] + x3[i] * (node){0,1};
fft(x1,-1);
fft(x2,-1);
for (int i = 0;i < n;i++)
a[i] = ((1ll * ((long long)(x1[i].x + 0.1)) % p * bs % p * bs % p + 1ll * ((long long)(x2[i].x + 0.1) % p) * bs % p) % p + ((long long)(x1[i].y + 0.1)) % p) % p;
}
int main()
{
scanf("%d%d%d%d%d%d",&n,&k,&L,&x,&y,&p);
for (int i = 1;i <= n;i++)
for (int j = 1;j <= n;j++)
scanf("%d",&W.a[i][j]);
g = get(p);
ow[1] = mypow(g,(p - 1) / k,p);
for (int i = 2;i <= k;i++)
ow[i % k] = 1ll * ow[i - 1] * ow[1] % p;
prework(k * 4);
for (int i = 1;i < maxn;i <<= 1)
for (int j = 0;j < i;j++)
w[i + j] = (node){cos(Pi * j / i),sin(Pi * j / i)};
for (int i = 1;i <= 3;i++)
I.a[i][i] = 1;
int lim = k * 2;
for (int d = 0;d < lim;d++)
A[d] = ow[(k - C2(d)) % k];
for (int d = 0;d < k;d++)
{
getmatrix(d);
B[d] = 1ll * ow[C2(d)] * G.a[1][y] % p;
}
reverse(A,A + lim);
mtt(A,B,lim,p);
reverse(A,A + lim);
int ik = mypow(k,p - 2,p);
for (int d = 0;d < k;d++)
{
A[d] = 1ll * ik * ow[C2(d)] % p * A[d] % p;
printf("%d\n",(A[d] + p) % p);
}
return 0;
}