[解题报告]【UNR #4】校园闲逛
题意
\(n\) 个点 \(m\) 条边的带权有向图, \(Q\) 次询问, 求出 \(x\) 到 \(y\) 之间长度为 \(v\) 的路径条数.
数据范围
\(n \le 8,\ m \le 300000,\ max_v \le 65000,\ Q \le 1000\).
图的边权大于等于 1.
思路
路径长度为 \(v\) 显然是一个背包的形式, 可以用生成函数搞.
图的限制可以考虑用矩阵解决.
那么我们设两个元素为多项式的矩阵.
矩阵 \(F\), 其中的元素为 \(i\) 到 \(j\) 的长度为 \(v\) 的路径条数的生成函数.
矩阵 \(G\), 其中的元素为 \(i\) 到 \(j\) 的长度为 \(v\) 的边数的生成函数.
那么可以得到
\[F = GF + I
\]
(\(I\) 为单位矩阵, 表示待在原地不走的情况.)
从而得到
\[F = \frac{I}{I - G}
\]
因为题目保证图的边权大于等于 1, 所以 G 中的元素的零次项系数都为 0, 所以可以直接对 \(I - G\) 求逆.
矩阵求逆的时候我们可以利用点值优化多项式的运算 (需要注意的是 FFT 是循环卷积). 总共需要进行 \(O(n)\) 次多项式求逆, \(O(n^2)\) 次 DFT 和 IDFT, \(O(n^3)\) 次多项式加减法, 总时间复杂度为 \(O(n v\log v + n^2 v \log v + n^3 v)\). 具体实现可以看代码.
后序
由于上述做法常数比较大, 所以需要各种多项式卡常技巧. 比如预处理单位根, 用 unsigned long long 减少取模次数等等.
建议不要封装结构体, 写丑了会比较慢.
貌似还有一种多项式套矩阵的做法, 常数会更小, 改天再搞吧 (gugugu).
代码
#include <cstdio>
#include <cstring>
#include <ctime>
#include <iostream>
using namespace std;
typedef long long ll;
const int _ = 8 + 1;
const int __ = (1 << 17) + 1;
const int mod = 998244353, rt = 3;
int n, m, numE, numQ, tot, tt;
int Pw(int a, int p) {
int res = 1;
while (p) {
if (p & 1) res = (ll)res * a % mod;
a = (ll)a * a % mod;
p >>= 1;
}
return res;
}
int num[__], pwrt[2][__], inv[__], w[__];
unsigned long long q[__];
void Pre() {
tot = 1; while (tot <= n + n) tot <<= 1;
tt = (tot >> 1), inv[1] = 1;
for (int i = 2; i <= tot; ++i) inv[i] = (ll)inv[mod % i] * (mod - mod / i) % mod;
pwrt[0][0] = 1, pwrt[0][1] = Pw(rt, (mod - 1) / tot);
pwrt[1][0] = 1, pwrt[1][1] = Pw(pwrt[0][1], mod - 2);
for (int i = 2; i < tot; ++i) {
pwrt[0][i] = (ll)pwrt[0][i - 1] * pwrt[0][1] % mod;
pwrt[1][i] = (ll)pwrt[1][i - 1] * pwrt[1][1] % mod;
}
}
inline void NTT(int* f, bool ty, int t = tot) {
for (int i = 0; i < t; ++i)
if (i < num[i]) swap(f[i], f[num[i]]);
for (int i = 0; i < t; ++i) q[i] = f[i];
for (int len = 2; len <= t; len <<= 1) {
int gap = len >> 1, d = tot / len;
for (int i = 0, j = 0; i < gap; ++i, j += d) w[i] = pwrt[ty][j];
for (int i = 0, tmp; i < t; i += len)
for (int j = 0; j < gap; ++j) {
tmp = w[j] * q[i + j + gap] % mod;
q[i + j + gap] = q[i + j] + mod - tmp;
q[i + j] += tmp;
}
if (len == (1 << 20)) for (int i = 0; i < t; ++i) q[i] %= mod;
}
for (int i = 0; i < t; ++i) f[i] = q[i] % mod;
if (ty) {
for (int i = 0; i < t; ++i) f[i] = (ll)f[i] * inv[t] % mod;
memset(f + (t >> 1), 0, t << 1);
}
}
int a[__], b[__];
inline void Inv(int *h, int *f, int L) {
memset(a, 0, sizeof a), memset(b, 0, sizeof b);
b[0] = f[0];
for (int len = 2, t = 4; len <= L; len <<= 1, t <<= 1) {
memcpy(a, f, len << 2);
for (int i = 1; i < t; ++i) num[i] = (num[i >> 1] >> 1) | ((i & 1) ? t >> 1 : 0);
NTT(a, 0, t), NTT(b, 0, t);
for (int i = 0; i < t; ++i) b[i] = (ll)b[i] * (2 + mod - (ll)a[i] * b[i] % mod) % mod;
NTT(b, 1, t);
}
for (int i = 0; i < tot; ++i) h[i] = b[i];
}
int gi() {
int x = 0; char c = getchar();
while (!isdigit(c)) c = getchar();
while (isdigit(c)) x = (x << 3) + (x << 1) + c - '0', c = getchar();
return x;
}
int A[_][16 + 1][__];
int t1[__], t2[16 + 1][__], t3[__];
void Init() {
m = gi(), numE = gi(), numQ = gi(), n = gi();
for (int i = 1, x, y, w; i <= numE; ++i) {
x = gi(), y = gi(), w = gi();
++A[x][y][w];
}
Pre();
}
void Run() {
for (int i = 1; i <= m; ++i) A[i][i][0] = mod - 1, A[i][i + m][0] = mod - 1;
for (int i = 1; i <= m; ++i) {
Inv(t1, A[i][i], tot >> 1);
NTT(t1, 0);
for (int k = i; k <= m + i; ++k) {
NTT(A[i][k], 0);
for (int l = 0; l < tot; ++l) A[i][k][l] = (ll)A[i][k][l] * t1[l] % mod;
NTT(A[i][k], 1);
memcpy(t2[k], A[i][k], tot << 2);
NTT(t2[k], 0);
}
for (int j = 1; j <= m; ++j) {
if (j == i) continue;
memcpy(t1, A[j][i], tot << 2);
NTT(t1, 0);
for (int k = i + 1; k <= m + i; ++k) {
for (int l = 0; l < tot; ++l) t3[l] = (ll)t1[l] * t2[k][l] % mod;
NTT(t3, 1);
for (int l = 0; l < tt; ++l) A[j][k][l] = (A[j][k][l] - t3[l] + mod) % mod;
}
}
}
for (int i = 1, x, y, w; i <= numQ; ++i) {
x = gi(), y = gi(), w = gi();
printf("%d\n", A[x][m + y][w]);
}
}
int main() {
Init();
Run();
return 0;
}