[HNOI2019] 白兔之舞
Problem
有一张顶点数为 \((L+1)\times n\) 的有向图,每个节点用二元组 \((u,v)\) 来表示(\(0\le u\le L,1\le v\le n\)),节点 \((u_1,v_1)\) 到 \((u_2,v_2)\) 有 \(w_{v_1,v_2}\) 条不同的边,当且仅当 \(u_1<u_2\)。
初始时白兔在 \((0,x)\),每次沿着一条路跳到下一个节点,它可以在任意时刻停止(也可以在起点停止),或者第一维到达 \(L\) 也会停止。停止时白兔共跳了 \(m\) 步,第 \(i\) 个元素表示经过的第 \(i\) 条边。
给定 \(k\) 和 \(y(1\le y\le n)\),对于每个 \(t(0\le t<k)\),求有多少种舞曲(假设其长度为 \(m\))满足 \(m\bmod k=t\),且白兔最后停在了坐标第二维为 \(y\) 的顶点。答案对 \(p\)(一个质数)取模。
\(10^8<p<2^{30}\),\(1\le n\le 3\),\(1\le x,y\le n\),\(0\le w_{i,j}<p\),\(1\le k\le 65536\),\(k\mid p-1\),\(1\le L\le 10^8\)。
Sol
由题,我们写出来转移矩阵 \(A\)
列向量 \(X\) 满足 \(X_x=1\),其余为 \(0\)。
发现第一维和第二维在某种意义上独立。
变换 \(m\) 次得到 \(A^mX\),取出来第 \(y\) 项乘上组合数 \(L\choose m\) 即为长度为 \(m\) 的舞曲的种类数。
假设我们只需要求出 \(m\bmod k=t\) 的答案和,不难想到 单位根反演:
推导中暂时略去乘上的列向量 \(X\),只要最后补上去即可。注意到后半部分中相当于要求
由于 \(AI=IA=A\),满足 交换律,故可以使用二项式定理,得到
令 \(y_j=(w^jA+I)^LX\) 的第 \(y\) 项,要求 \(t=j\) 时的答案,则上面式子可以看成
也就是说我们要依次求出来 \(f(w^0),f(w^{-1}),f(w^{-2}),\cdots\)。直接 NTT
但 \(k\not\equiv 2^l\),所以我们需要 Bluestein's Algorithm
来做任意长度的卷积,详见我的另一篇博文 Chirp Z-Transform。
复杂度 \(\mathcal O(n^3k\log L+k\log k)\)。此题毒瘤还要 MTT
。
Code
#include <bits/stdc++.h>
using std::vector;
typedef long long LL;
const double PI = acos(-1);
const int N = (1 << 16) + 5;
int n, k, L, x, y, p, g, w, a[N * 4], b[N * 4], c[N * 4], ans[N];
struct Mat {
int a[3][3], n, m;
int *operator [] (int i) { return a[i]; }
Mat(int n = 0, int m = 0) : n(n), m(m) {
memset(a, 0, sizeof a);
}
} A, X, I;
Mat operator + (Mat a, Mat b) {
for (int i = 0; i < a.n; i++)
for (int j = 0; j < a.m; j++)
a[i][j] = (a[i][j] + b[i][j]) % p;
return a;
}
Mat operator * (Mat a, Mat b) {
Mat c(a.n, b.m);
for (int k = 0; k < a.m; k++)
for (int i = 0; i < a.n; i++)
for (int j = 0; j < b.m; j++)
c[i][j] = (c[i][j] + 1LL * a[i][k] * b[k][j]) % p;
return c;
}
Mat operator * (int k, Mat a) {
for (int i = 0; i < a.n; i++)
for (int j = 0; j < a.m; j++)
a[i][j] = 1LL * a[i][j] * k % p;
return a;
}
Mat qpow(Mat a, int b) {
Mat c(a.n, a.n);
for (int i = 0; i < a.n; i++) c[i][i] = 1;
for (; b; b >>= 1, a = a * a)
if (b & 1) c = c * a;
return c;
}
int qpow(int a, int b) {
int c = 1;
for (; b; b >>= 1, a = 1LL * a * a % p)
if (b & 1) c = 1LL * c * a % p;
return c;
}
struct comp {
double x, y;
comp(double x = 0, double y = 0) : x(x), y(y) {}
};
comp operator + (comp a, comp b) { return {a.x + b.x, a.y + b.y}; }
comp operator - (comp a, comp b) { return {a.x - b.x, a.y - b.y}; }
comp operator * (comp a, comp b) { return comp{a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x}; }
comp conj(comp a) { return {a.x, -a.y}; }
comp W[N * 4];
int pw[N * 4], ipw[N * 4];
void prework(int n) {
for (int i = 1; i < n; i <<= 1)
for (int j = 0; j < i; j++)
W[i + j] = {cos(PI / i * j), sin(PI / i * j)};
pw[0] = pw[1] = ipw[0] = ipw[1] = 1, pw[2] = w, ipw[2] = qpow(w, p - 2);
for (int i = 3; i < n; i++)
pw[i] = 1LL * pw[i - 1] * pw[2] % p, ipw[i] = 1LL * ipw[i - 1] * ipw[2] % p;
for (int i = 3; i < n; i++)
pw[i] = 1LL * pw[i] * pw[i - 1] % p, ipw[i] = 1LL * ipw[i] * ipw[i - 1] % p;
}
void fft(comp *a, int n, int op) {
static int rev[N * 4];
for (int i = 0; i < n; i++)
if ((rev[i] = rev[i >> 1] >> 1 | (i & 1 ? n >> 1 : 0)) > i) std::swap(a[i], a[rev[i]]);
for (int q = 1; q < n; q <<= 1)
for (int p = 0; p < n; p += q << 1)
for (int i = 0; i < q; i++) {
comp t = W[q + i] * a[p + q + i];
a[p + q + i] = a[p + i] - t; a[p + i] = a[p + i] + t;
}
if (op) return;
for (int i = 0; i < n; i++) a[i].x /= n, a[i].y /= n;
std::reverse(a + 1, a + n);
}
int getsz(int n) {
int x = 1;
while (x < n) x <<= 1;
return x;
}
void conv(int *x, int *y, int *z, int n) {
static comp a[N * 4], b[N * 4], da[N * 4], db[N * 4], dc[N * 4], dd[N * 4];
for (int i = 0; i < n; i++)
a[i] = comp(x[i] >> 15, x[i] & 32767), b[i] = comp(y[i] >> 15, y[i] & 32767);
fft(a, n, 1), fft(b, n, 1);
for (int i = 0; i < n; i++) {
int j = (n - 1) & (n - i);
static comp a1, a2, b1, b2;
a1 = (a[i] + conj(a[j])) * comp{0.5, 0};
a2 = (a[i] - conj(a[j])) * comp{0, -0.5};
b1 = (b[i] + conj(b[j])) * comp{0.5, 0};
b2 = (b[i] - conj(b[j])) * comp{0, -0.5};
da[i] = a1 * b1, db[i] = a1 * b2, dc[i] = a2 * b1, dd[i] = a2 * b2;
}
for (int i = 0; i < n; i++)
a[i] = da[i] + db[i] * comp{0, 1}, b[i] = dc[i] + dd[i] * comp{0, 1};
fft(a, n, 0), fft(b, n, 0);
for (int i = 0; i < n; i++) {
int ax = (LL)(a[i].x + 0.5) % p, ay = (LL)(a[i].y + 0.5) % p, bx = (LL)(b[i].x + 0.5) % p, by = (LL)(b[i].y + 0.5) % p;
z[i] = (((LL)ax << 30) + ((LL)(ay + bx) << 15) + by) % p;
}
}
bool check(int n) {
for (int i = 2; i * i <= n; i++)
if (!(n % i)) return 0;
return 1;
}
void getG() {
vector<int> fac;
int tmp = p - 1;
for (int i = 2; !check(tmp); i++)
while (tmp % i == 0) fac.push_back(i), tmp /= i;
if (tmp != 1) fac.push_back(tmp);
for (;; g++) {
int ok = 1;
for (int i = 0; i < fac.size(); i++)
if (qpow(g, (p - 1) / fac[i]) == 1) { ok = 0; break; }
if (ok) break;
}
}
int main() {
scanf("%d%d%d%d%d%d", &n, &k, &L, &x, &y, &p);
A = Mat(n, n);
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
scanf("%d", &A[j][i]);
X = Mat(n, 1), X[x - 1][0] = 1;
I = Mat(n, n);
for (int i = 0; i < n; i++) I[i][i] = 1;
g = 1; getG();
w = qpow(g, (p - 1) / k);
int l = getsz(k * 2); prework(l);
for (int i = 0; i < k; i++)
a[i] = 1LL * (qpow(qpow(w, i) * A + I, L) * X)[y - 1][0] * ipw[i] % p;
std::reverse(a, a + k + 1);
for (int i = 0; i < k * 2; i++) b[i] = pw[i];
conv(a, b, c, l);
for (int i = 0; i < k; i++)
ans[i ? k - i : 0] = 1LL * c[k + i] * ipw[i] % p * qpow(k, p - 2) % p;
for (int i = 0; i < k; i++)
printf("%d\n", ans[i]);
return 0;
}