简单计数 题解
最近在搞线代,拿了一道以前zzq的模拟题做了一下
前置技能:BEST 定理
BEST 定理:一个有向欧拉图的欧拉回路个数,等于以任意一个固定点为根的内向树个数,乘上所有点的 (deg[i]-1)! 的乘积,其中 deg[i] 表示点 i 的出度。
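我们考虑拆开算贡献:对每一对首尾相接的边 (u,v),(v,w),计算先经过 (u,v) 紧接着经过 (v,w) 的欧拉回路个数。做法是把 (u,v),(v,w) 断开,再连上一条新边 (u,w)。如果这样做之后 v 变成了孤立点,需要特判一下;否则我们发现矩阵中被修改的位置只会有 4 个,即 (u,v)、(v,v)、(v,w)、(u,w),其中前 3 个都和 v 相关(位于 v 所在的行或列),把 v 所在的行和列直接删除掉即可;剩下的那个位置 (u,w) 考虑用拉普拉斯展开处理。所以我们要求的本质就是:对于所有 i,去掉第 i 行第 i 列后的矩阵的逆矩阵。分治消元即可。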
//waz
// Reads a labelled directed graph from count.in and writes one integer
// (mod 998244353) to count.out. For every pair of consecutive edges
// (from -> mid, mid -> to) whose stored labels are equal and non-zero it adds
// a BEST-theorem style term built from a Laplacian minor determinant, an
// entry of the scaled inverse of that minor, and factorials of out-degrees.
#include <cstdio>
#include <cstring>
#include <utility>
#include <vector>
using namespace std;

#define mp make_pair
#define pb push_back
#define fi first
#define se second
#define ALL(x) (x).begin(), (x).end()
#define SZ(x) ((int)((x).size()))

typedef pair<int, int> PII;
typedef vector<int> VI;
typedef long long int64;
typedef unsigned int uint;
typedef unsigned long long uint64;

#define gi(x) ((x) = F())
#define gii(x, y) (gi(x), gi(y))
#define giii(x, y, z) (gii(x, y), gi(z))

// Reads the next (optionally negative) decimal integer from stdin, skipping
// every character that is neither a digit nor '-'.
// Returns 0 if end of input is reached before a number starts; the original
// kept the character in a `char` and looped forever in that situation.
int F() {
    int ch = getchar();
    while (ch != EOF && ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
    if (ch == EOF) return 0;
    int sign = 1;
    if (ch == '-') {
        sign = -1;
        ch = getchar();
    }
    int value = 0;
    for (; ch >= '0' && ch <= '9'; ch = getchar()) value = value * 10 + (ch - '0');
    return sign * value;
}

const int mod = 998244353;

// (a + b) mod `mod`, for a and b already in [0, mod).
int inc(int a, int b) {
    a += b;
    return a >= mod ? a - mod : a;
}

// (a - b) mod `mod`, for a and b already in [0, mod).
int dec(int a, int b) {
    a -= b;
    return a < 0 ? a + mod : a;
}

// a^x mod `mod` by binary exponentiation. fpow(v, mod - 2) is used below as
// the modular inverse of v; note that fpow(0, mod - 2) is 0.
int fpow(int a, int x) {
    int ret = 1;
    for (; x; x >>= 1) {
        if (x & 1) ret = 1LL * ret * a % mod;
        a = 1LL * a * a % mod;
    }
    return ret;
}

const int N = 310;

// Square matrix over Z_mod with 1-based indices; only rows/columns 1..n are
// meaningful. Storage is a fixed N x N int array, so objects are large and
// every by-value pass copies the whole array.
struct matrix {
    int n;
    int a[N][N];

    // All entries zero; n starts at 0 (the original left it uninitialised).
    matrix() : n(0) { memset(a, 0, sizeof a); }

    int *operator[](int pos) { return a[pos]; }

    // Swaps rows i and j.
    void swap(int i, int j) { std::swap(a[i], a[j]); }

    // Transposes the n x n part in place.
    void rev() {
        for (int i = 1; i <= n; ++i)
            for (int j = i + 1; j <= n; ++j) std::swap(a[i][j], a[j][i]);
    }

    // Multiplies every entry of x by the scalar y (mod `mod`).
    friend matrix operator * (matrix x, int y) {
        for (int i = 1; i <= x.n; ++i)
            for (int j = 1; j <= x.n; ++j) x[i][j] = 1LL * x[i][j] * y % mod;
        return x;
    }

    // Returns a copy with row t and column t removed and n reduced by one.
    // The compaction runs in place on the copy: each cell is read from an
    // index >= the one being written, so no cell is read after it has been
    // overwritten. Row/column n + 1 is read, so n must be at most N - 2.
    matrix del(int t) {
        matrix r = *this;
        for (int i = 1; i <= r.n; ++i)
            for (int j = 1; j <= r.n; ++j) {
                int x = i, y = j;
                if (i >= t) ++x;
                if (j >= t) ++y;
                r[i][j] = r[x][y];
            }
        r.n--;
        return r;
    }

    // Debug helper: prints n, then the matrix rows as comma-separated values.
    void out() {
        printf("debug : %d\n", n);
        for (int i = 1; i <= n; ++i)
            for (int j = 1; j <= n; ++j) printf("%d%c", a[i][j], ",\n"[j == n]);
    }
} c[N], o;

// Determinant of `a` modulo `mod` via Gaussian elimination on a by-value
// copy. Each row swap negates the running result; returns 0 as soon as a
// column has no usable pivot.
int det(matrix a) {
    int ans = 1;
    for (int i = 1; i <= a.n; ++i) {
        if (!a[i][i]) {
            for (int j = i + 1; j <= a.n; ++j)
                if (a[j][i]) {
                    a.swap(i, j);
                    ans = dec(mod, ans);
                    break;
                }
        }
        if (!a[i][i]) return 0;
        ans = 1LL * ans * a[i][i] % mod;
        int v = fpow(a[i][i], mod - 2);
        for (int j = i; j <= a.n; ++j) a[i][j] = 1LL * a[i][j] * v % mod;
        for (int j = i + 1; j <= a.n; ++j) {
            int t = a[j][i];
            for (int k = i; k <= a.n; ++k) a[j][k] = dec(a[j][k], 1LL * t * a[i][k] % mod);
        }
    }
    return ans;
}

// Gauss-Jordan inversion modulo `mod`: reduces `a` to the identity while
// applying the same row operations to an identity matrix, which is returned.
// There is no singularity check; a zero pivot is "inverted" to 0 by fpow, so
// the result is only meaningful for invertible input.
matrix inv(matrix a) {
    matrix b;
    b.n = a.n;
    for (int i = 1; i <= b.n; ++i) b[i][i] = 1;
    for (int i = 1; i <= a.n; ++i) {
        if (!a[i][i]) {
            for (int j = i + 1; j <= a.n; ++j)
                if (a[j][i]) {
                    a.swap(i, j);
                    b.swap(i, j);
                    break;
                }
        }
        int v = fpow(a[i][i], mod - 2);
        for (int j = 1; j <= a.n; ++j)
            a[i][j] = 1LL * a[i][j] * v % mod, b[i][j] = 1LL * b[i][j] * v % mod;
        for (int j = 1; j <= a.n; ++j) {
            if (j == i) continue;
            int t = a[j][i];
            for (int k = 1; k <= a.n; ++k)
                a[j][k] = dec(a[j][k], 1LL * t * a[i][k] % mod),
                b[j][k] = dec(b[j][k], 1LL * t * b[i][k] % mod);
        }
    }
    return b;
}

// Runs the Gauss-Jordan step of inv() for pivot columns l..r only, in place
// on `a`, mirroring every row operation on `b`.
// Unlike inv(), the search for a replacement pivot row scans all rows 1..n.
void guess(matrix &a, matrix &b, int l, int r) {
    for (int i = l; i <= r; ++i) {
        if (!a[i][i]) {
            for (int j = 1; j <= a.n; ++j)
                if (a[j][i]) {
                    a.swap(i, j);
                    b.swap(i, j);
                    break;
                }
        }
        int v = fpow(a[i][i], mod - 2);
        for (int j = 1; j <= a.n; ++j)
            a[i][j] = 1LL * a[i][j] * v % mod, b[i][j] = 1LL * b[i][j] * v % mod;
        for (int j = 1; j <= a.n; ++j) {
            if (j == i) continue;
            int t = a[j][i];
            for (int k = 1; k <= a.n; ++k)
                a[j][k] = dec(a[j][k], 1LL * t * a[i][k] % mod),
                b[j][k] = dec(b[j][k], 1LL * t * b[i][k] % mod);
        }
    }
}

int n, m;                            // vertex count, edge count
int x[N * N], u[N * N], v[N * N];    // per edge: label, tail, head
int w[N][N];                         // w[a][b] = label stored for edge a -> b (0 = none)
int g[N];                            // g[i]: determinant factor used for vertex i
int deg[N];                          // out-degree of each vertex
int fac[N], rfac[N];                 // factorials and inverse factorials mod `mod`

// Divide and conquer over pivot indices [l, r]. `now` holds a partially
// eliminated (matrix, accumulated row operations) pair in which no pivot of
// [l, r] has been eliminated yet. Before descending into one half, the pivots
// of the other half are eliminated on a private copy, so at a leaf every
// pivot except l has been processed. The leaf stores into c[l] the
// accumulated matrix with row/column l removed, scaled by g[l] and
// transposed; this takes the place of the per-vertex
// det / inv / scale / rev computation kept as a comment in main().
// NOTE(review): each level keeps by-value copies of two N x N matrices, so
// the recursion needs a generous stack limit; confirm for the target judge.
void fz(int l, int r, pair<matrix, matrix> now) {
    if (l == r) {
        c[l] = now.se.del(l);
        c[l] = c[l] * g[l];
        c[l].rev();
        return;
    }
    int mid = (l + r) >> 1;
    pair<matrix, matrix> t = now;
    guess(t.fi, t.se, l, mid);           // eliminate the left half, recurse right
    fz(mid + 1, r, t);
    guess(now.fi, now.se, mid + 1, r);   // eliminate the right half, recurse left
    fz(l, mid, now);
}

// Input (count.in): n m, then m triples "x u v" giving an edge u -> v with
// label x. Output (count.out): the accumulated sum modulo `mod`.
int main() {
    // After a failed freopen the stream is unusable, so stop instead of
    // continuing with no input or no output.
    if (!freopen("count.in", "r", stdin)) return 1;
    if (!freopen("count.out", "w", stdout)) return 1;

    gii(n, m);
    o.n = n;

    fac[0] = 1;
    for (int i = 1; i <= n; ++i) fac[i] = 1LL * fac[i - 1] * i % mod;
    rfac[n] = fpow(fac[n], mod - 2);
    for (int i = n; i; --i) rfac[i - 1] = 1LL * rfac[i] * i % mod;

    // Build o = (out-degree diagonal) - (adjacency) and remember edge labels.
    for (int i = 1; i <= m; ++i) {
        giii(x[i], u[i], v[i]);
        w[u[i]][v[i]] = x[i];
        ++deg[u[i]];
        o[u[i]][u[i]] = inc(o[u[i]][u[i]], 1);
        o[u[i]][v[i]] = dec(o[u[i]][v[i]], 1);
    }

    // The determinant of o without row/column 1 is reused for every vertex.
    c[0] = o.del(1);
    g[0] = det(c[0]);
    for (int i = 1; i <= n; ++i) g[i] = g[i - 1];

    matrix I;
    I.n = n;
    for (int i = 1; i <= n; ++i) I[i][i] = 1;
    fz(1, n, mp(o, I));
    // Direct per-vertex equivalent that fz() replaces:
    //for (int i = 1; i <= n; ++i) c[i] = o.del(i);
    //for (int i = 1; i <= n; ++i) g[i] = det(c[i]), c[i] = inv(c[i]), c[i] = c[i] * g[i], c[i].rev();

    int mul = 1, ans = 0;
    for (int i = 1; i <= n; ++i) {
        if (deg[i] == 0) {
            // The original indexed fac[-1] here (out-of-bounds read). Use a
            // zero factor instead.
            // NOTE(review): presumably the input guarantees out-degree >= 1
            // for every vertex -- confirm against the problem statement.
            mul = 0;
        } else {
            mul = 1LL * mul * fac[deg[i] - 1] % mod;
        }
    }

    for (int mid = 1; mid <= n; ++mid) {
        for (int from = 1; from <= n; ++from) {
            if (!w[from][mid]) continue;
            for (int to = 1; to <= n; ++to) {
                if (w[mid][to] != w[from][mid]) continue;
                // Indices of `from` / `to` once row/column `mid` is removed.
                int i = from > mid ? from - 1 : from;
                int j = to > mid ? to - 1 : to;
                int ret = g[mid];
                if (deg[mid] == 1) {
                    // mid has only this one outgoing edge: no correction term.
                    ans = (ans + 1LL * mul * ret) % mod;
                } else {
                    ret = dec(ret, c[mid][i][j]);
                    // Replace mid's factor (deg - 1)! by (deg - 2)!.
                    int gg = mul;
                    gg = 1LL * gg * rfac[deg[mid] - 1] % mod;
                    gg = 1LL * gg * fac[deg[mid] - 2] % mod;
                    ans = (ans + 1LL * gg * ret) % mod;
                }
            }
        }
    }
    printf("%d\n", ans);
}