[解题报告]【UNR #4】校园闲逛

传送🚪

题意

\(n\) 个点 \(m\) 条边的带权有向图, \(Q\) 次询问, 求出 \(x\)\(y\) 之间长度为 \(v\) 的路径条数.

数据范围

\(n \le 8,\ m \le 300000,\ max_v \le 65000,\ Q \le 1000\).

图的边权大于等于 1.

思路

路径长度为 \(v\) 显然是一个背包的形式, 可以用生成函数搞.

图的限制可以考虑用矩阵解决.

那么我们设两个元素为多项式的矩阵.

矩阵 \(F\), 其中的元素为 \(i\)\(j\) 的长度为 \(v\)路径条数的生成函数.

矩阵 \(G\), 其中的元素为 \(i\)\(j\) 的长度为 \(v\)边数的生成函数.

那么可以得到

\[F = GF + I \]

(\(I\) 为单位矩阵, 表示待在原地不走的情况.)

从而得到

\[F = \frac{I}{I - G} \]

因为题目保证图的边权大于等于 1, 所以 G 中的元素的零次项系数都为 0, 所以可以直接对 \(I - G\) 求逆.

矩阵求逆的时候我们可以利用点值优化多项式的运算 (需要注意的是 FFT 是循环卷积). 总共需要进行 \(O(n)\) 次多项式求逆, \(O(n^2)\) 次 DFT 和 IDFT, \(O(n^3)\) 次多项式加减法, 总时间复杂度为 \(O(n v\log v + n^2 v \log v + n^3 v)\). 具体实现可以看代码.

后序

由于上述做法常数比较大, 所以需要各种多项式卡常技巧. 比如预处理单位根, 用 unsigned long long 减少取模次数等等.

建议不要封装结构体, 写丑了会比较慢.

貌似还有一种多项式套矩阵的做法, 常数会更小, 改天再搞吧 (gugugu).

代码

#include <cstdio>
#include <cstring>
#include <ctime>
#include <iostream>

using namespace std;

typedef long long ll;

const int _ = 8 + 1;
const int __ = (1 << 17) + 1;
const int mod = 998244353, rt = 3;

int n, m, numE, numQ, tot, tt;

int Pw(int a, int p) {
  int res = 1;
  while (p) {
    if (p & 1) res = (ll)res * a % mod;
    a = (ll)a * a % mod;
    p >>= 1;
  }
  return res;
}

int num[__], pwrt[2][__], inv[__], w[__];
unsigned long long q[__];

void Pre() {
  tot = 1; while (tot <= n + n) tot <<= 1;
  tt = (tot >> 1), inv[1] = 1;
  for (int i = 2; i <= tot; ++i) inv[i] = (ll)inv[mod % i] * (mod - mod / i) % mod;
  pwrt[0][0] = 1, pwrt[0][1] = Pw(rt, (mod - 1) / tot);
  pwrt[1][0] = 1, pwrt[1][1] = Pw(pwrt[0][1], mod - 2);
  for (int i = 2; i < tot; ++i) {
    pwrt[0][i] = (ll)pwrt[0][i - 1] * pwrt[0][1] % mod;
    pwrt[1][i] = (ll)pwrt[1][i - 1] * pwrt[1][1] % mod;
  }
}

inline void NTT(int* f, bool ty, int t = tot) {
  for (int i = 0; i < t; ++i)
    if (i < num[i]) swap(f[i], f[num[i]]);
  for (int i = 0; i < t; ++i) q[i] = f[i];
  for (int len = 2; len <= t; len <<= 1) {
    int gap = len >> 1, d = tot / len;
    for (int i = 0, j = 0; i < gap; ++i, j += d) w[i] = pwrt[ty][j];
    for (int i = 0, tmp; i < t; i += len)
      for (int j = 0; j < gap; ++j) {
        tmp = w[j] * q[i + j + gap] % mod;
        q[i + j + gap] = q[i + j] + mod - tmp;
        q[i + j] += tmp;
      }
    if (len == (1 << 20)) for (int i = 0; i < t; ++i) q[i] %= mod;
  }
  for (int i = 0; i < t; ++i) f[i] = q[i] % mod;
  if (ty) {
    for (int i = 0; i < t; ++i) f[i] = (ll)f[i] * inv[t] % mod;
    memset(f + (t >> 1), 0, t << 1);
  }
}

int a[__], b[__];

inline void Inv(int *h, int *f, int L) {
  memset(a, 0, sizeof a), memset(b, 0, sizeof b);
  b[0] = f[0];
  for (int len = 2, t = 4; len <= L; len <<= 1, t <<= 1) {
    memcpy(a, f, len << 2);
    for (int i = 1; i < t; ++i) num[i] = (num[i >> 1] >> 1) | ((i & 1) ? t >> 1 : 0);
    NTT(a, 0, t), NTT(b, 0, t);
    for (int i = 0; i < t; ++i) b[i] = (ll)b[i] * (2 + mod - (ll)a[i] * b[i] % mod) % mod;
    NTT(b, 1, t);
  }
  for (int i = 0; i < tot; ++i) h[i] = b[i];
}

int gi() {
  int x = 0; char c = getchar();
  while (!isdigit(c)) c = getchar();
  while (isdigit(c)) x = (x << 3) + (x << 1) + c - '0', c = getchar();
  return x;
}

int A[_][16 + 1][__];
int t1[__], t2[16 + 1][__], t3[__];

void Init() {
  m = gi(), numE = gi(), numQ = gi(), n = gi();
  for (int i = 1, x, y, w; i <= numE; ++i) {
    x = gi(), y = gi(), w = gi();
    ++A[x][y][w];
  }
  Pre();
}

void Run() {
  for (int i = 1; i <= m; ++i) A[i][i][0] = mod - 1, A[i][i + m][0] = mod - 1;
  for (int i = 1; i <= m; ++i) {
    Inv(t1, A[i][i], tot >> 1);
    NTT(t1, 0);
    for (int k = i; k <= m + i; ++k) {
      NTT(A[i][k], 0);
      for (int l = 0; l < tot; ++l) A[i][k][l] = (ll)A[i][k][l] * t1[l] % mod;
      NTT(A[i][k], 1);
      memcpy(t2[k], A[i][k], tot << 2);
      NTT(t2[k], 0);
    }
    for (int j = 1; j <= m; ++j) {
      if (j == i) continue;
      memcpy(t1, A[j][i], tot << 2);
      NTT(t1, 0);
      for (int k = i + 1; k <= m + i; ++k) {
        for (int l = 0; l < tot; ++l) t3[l] = (ll)t1[l] * t2[k][l] % mod;
        NTT(t3, 1);
        for (int l = 0; l < tt; ++l) A[j][k][l] = (A[j][k][l] - t3[l] + mod) % mod;
      }
    }
  }
  for (int i = 1, x, y, w; i <= numQ; ++i) {
    x = gi(), y = gi(), w = gi();
    printf("%d\n", A[x][m + y][w]);
  }
}

int main() {
  Init();
  Run();
  return 0;
}
posted @ 2020-11-02 10:26  BruceW  阅读(165)  评论(1编辑  收藏  举报