CF1697E
令 \(dis(i,j)\) 表示点 \(i,j\) 之间的曼哈顿距离。
如果钦定两个点 \(i,j\) 染同一种颜色,那么就可以发现这些东西:
- 不能存在点 \(k\) 与 \(i,j\) 不同色,并且满足 \(dis(i,k)\lt dis(i,j)\) 或 \(dis(j,k)\lt dis(i,j)\)。
- 对于所有点 \(k\) 满足 \(dis(i,k)=dis(i,j)\) 或 \(dis(j,k)=dis(i,j)\),那么 \(k\) 也要染上 \(i,j\) 的颜色。
这样就可以找到一个可能可以染上同一种颜色的点集,再经过判断之后就可以知道所有可以染上同一种颜色的点集。称这种点集为合法点集。
接着,可以发现合法点集的大小至多为 \(4\)。
证明:任意取两个点,分别以它们为圆心,它们的距离为半径,作曼哈顿圆(就是 \(45\)° 倾斜的正方形),最多只有两个交点。
但是并没有用。
也不难发现合法点集不可能有交。
证明:
- 不妨假设一个点在两个合法点集内,这两个点集对应的相同的曼哈顿距离分别是 \(d_1,d_2\)。
- 如果 \(d_1=d_2\),那么它们应该被染成同样的颜色,发现不合法了。
- 否则如果 \(d_1\ne d_2\),那么对于 \(d\) 较大的点集,它违反了上面的第一个结论。
把每个点向其所有的最近点连接一条有向边,那么若一些点是合法点集,当且仅当它们构成了完全图。
于是对于每个点,找出它能到的所有点,判断这些点两两之间是否都有边,时间复杂度 \(\mathcal O(n^3)\),这样就处理出了所有合法点集。
由于点集大小至多为 \(4\),那么可以直接枚举点集大小分别为 \(2,3,4\) 的数量,不过太麻烦了。
考虑 DP,令 \(siz_i\) 表示第 \(i\) 个合法点集的大小,设总共有 \(m\) 个合法点集。
令 \(f(i,j)\) 表示前 \(i\) 个点集,用了 \(j\) 种颜色的方案数。
则有
\[f(i,j)=[siz_i\gt1](n-j+1)f(i-1,j-1)+[j\ge siz_i]\dbinom{n-j+siz_i}{siz_i}siz_i!f(i-1,j-siz_i)
\]
最终答案是
\[\sum_{i=1}^n f(m,i)
\]
Code:
#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define fi first
#define se second
typedef pair <int, int> pii;
typedef long long ll;
const int N = 105, mod = 998244353, inf = 0x3f3f3f3f;
struct mint {
int v = 0;
mint(int _v = 0) : v(_v) {}
mint &operator += (const mint &X) { return (v += X.v) >= mod ? v -= mod : v, *this; }
mint &operator -= (const mint &X) { return (v += mod - X.v) >= mod ? v -= mod : v, *this; }
mint &operator *= (const mint &X) { return v = 1ll * v * X.v % mod, *this; }
mint &operator /= (const mint &X) { return *this *= X.inv(); }
mint qpow(int y) const { mint res = 1, x = *this; while (y) { if (y & 1) res *= x; x *= x; y >>= 1; } return res; }
mint inv() const { return qpow(mod - 2); }
friend mint operator + (const mint &A, const mint &B) { return mint(A) += B; }
friend mint operator - (const mint &A, const mint &B) { return mint(A) -= B; }
friend mint operator * (const mint &A, const mint &B) { return mint(A) *= B; }
friend mint operator / (const mint &A, const mint &B) { return mint(A) /= B; }
};
int n;
pii a[N];
int dis[N][N], mn[N];
mint fac[N], inv[N];
vector <int> G[N], tmp;
int vis[N], col;
int siz[N], tot;
mint f[N];
void init(int n) {
fac[0] = 1;
for (int i = 1; i <= n; ++i) fac[i] = fac[i - 1] * i;
inv[n] = fac[n].inv();
for (int i = n - 1; ~i; --i) inv[i] = inv[i + 1] * (i + 1);
}
mint C(int n, int m) {
if (n < 0 || m < 0 || n < m) return 0;
return fac[n] * inv[n - m] * inv[m];
}
int dist(int x, int y) {
return abs(a[x].fi - a[y].fi) + abs(a[x].se - a[y].se);
}
bool dfs(int u) {
vis[u] = col, tmp.pb(u);
bool f = 1;
for (int v : G[u]) {
if (!vis[v]) f &= dfs(v);
else if (vis[u] != vis[v]) f = 0;
}
return f;
}
int main() {
scanf("%d", &n);
init(n);
for (int i = 1; i <= n; ++i) scanf("%d%d", &a[i].fi, &a[i].se);
for (int i = 1; i <= n; ++i)
for (int j = i + 1; j <= n; ++j)
dis[i][j] = dis[j][i] = dist(i, j);
for (int i = 1; i <= n; ++i) {
mn[i] = inf;
for (int j = 1; j <= n; ++j) if (i != j)
mn[i] = min(mn[i], dis[i][j]);
}
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= n; ++j) if (i != j && dis[i][j] == mn[i])
G[i].pb(j);
}
for (int i = 1; i <= n; ++i) if (!vis[i]) {
tmp.clear(), ++col; bool flag = dfs(i);
for (int x : tmp) for (int y : tmp) if (x != y && (dis[x][y] != mn[x] || dis[x][y] != mn[y])) flag = 0;
if (!flag) {
for (int x : tmp) vis[x] = 0;
vis[i] = 1, siz[++tot] = 1;
}
else siz[++tot] = tmp.size();
}
f[0] = 1;
for (int i = 1; i <= tot; ++i, f[0] = 0) {
for (int j = n; j; --j) {
f[j] = 0;
if (siz[i] > 1) f[j] += f[j - 1] * (n - j + 1);
if (j >= siz[i]) f[j] += f[j - siz[i]] * C(n - j + siz[i], siz[i]) * fac[siz[i]];
}
}
mint ans = 0;
for (int i = 1; i <= n; ++i) ans += f[i];
printf("%d", ans);
return 0;
}