【题解】P4100 [HEOI2013]钙铁锌硒维生素
思路大致来自仙人学长 cyffff.
数学不好,大概有错,大家快逃。
思路
高斯消元 + 匈牙利。
首先需要一些简单的线代知识。
-
线性组合:如果向量 \(B = \sum\limits_{i = 1}^n k_i A_i\),则称 \(B\) 是向量集 \(A\) 的线性组合。
-
线性有关 / 线性无关:如果一组向量中的任意一个向量都不是其余 \(n - 1\) 个向量的线性组合,则称这组向量线性无关,反之为线性有关。
-
定理:如果 \(n\) 维向量集线性无关,则所有 \(n\) 维向量都可以表示成该向量集的线性组合。
题意转化过来就是给定两个线性无关的 \(n\) 维向量集 \(A, B\),求一个字典序最小的排列 \(p\),使得将任意 \(A_i\) 替换为 \(B_{p_i}\) 后,得到的向量集线性无关。
假设 \(A_i\) 可以被 \(B_j\) 等价替换,那么 \(B_j\) 不能是 \(A_1, \cdots, A_{i - 1}, A_{i + 1}, \cdots, A_n\) 的线性组合,不然得到的向量集线性有关。
考虑令 \(R_{i, j}\) 表示当 \(B_i\) 用 \(A\) 的线性组合表示时 \(A_j\) 的系数,显然当 \(R_{i, j} = 0\) 时 \(B_i\) 是 \(A_1, \cdots, A_{j - 1}, A_{j + 1}, \cdots, A_n\) 的线性组合,也就是 \(B_i\) 不能被 \(A_j\) 替换。
于是 \(B_i\) 可以被 \(A_j\) 替换的充要条件是 \(R_{i, j} \neq 0\).
注意到 \(B_{i, j} = \sum\limits_{k = 1}^n R_{i, k} A_{k, j}\),将 \(R\) 也看作矩阵,则 \(B = R A\),也就是 \(R = B A^{-1}\).
注意到把 \(A\) 消成单位矩阵 \(I\) 等价于对其乘以 \(A^{-1}\),那么只需要在高斯消元的时候对 \(B\) 进行同样的操作就可以求出 \(A^{-1} B\).
- 结论:注意高斯消元是初等行变换,等价于左乘逆元。
但是要求的是右乘,考虑对 \(A, B, R\) 同时进行转置或者把高斯消元换成初等列变换。
于是 \(B_i\) 可以被 \(A_j\) 替换的充要条件是 \(R_{j, i} \neq 0\),即 \(R_{i, j}\) 实际上表示 \(A_i\) 是否能替换 \(B_j\).
那么问题变成给定一张二分图,求一个字典序最小的匹配。
考虑先求出任意一个匹配再调整成最优解。
考虑每个左部点 \(u\) 匹配的右部点 \(v\),如果原先的匹配中存在一条包含 \((u, v)\) 的增广路(环),那么可以把 \(u\) 匹配的结点换成 \(v\).
于是升序对每个结点进行贪心就行,时间复杂度同朴素匈牙利。
代码
#include <cstdio>
#include <cmath>
#include <cstring>
#include <iostream>
using namespace std;
typedef double db;
const int maxn = 305;
const db eps = 1e-8;
int n;
int mch[maxn];
bool vis[maxn];
db a[maxn][maxn], b[maxn][maxn];
bool dfs1(int u)
{
for (int v = 1; v <= n; v++)
if ((fabs(b[u][v]) > eps) && (!vis[v]))
{
vis[v] = true;
if ((!mch[v]) || (dfs1(mch[v]))) return mch[v] = u, true;
}
return false;
}
int dfs2(int u, int fr)
{
for (int v = 1; v <= n; v++)
if ((fabs(b[u][v]) > eps) && (!vis[v]))
{
vis[v] = true;
if ((mch[v] == fr) || ((mch[v] > fr) && dfs2(mch[v], fr))) return mch[v] = u, v;
}
return 0;
}
int main()
{
scanf("%d", &n);
for (int i = 1; i <= n; i++)
for (int j = 1; j <= n; j++)
scanf("%lf", &a[j][i]);
for (int i = 1; i <= n; i++)
for (int j = 1; j <= n; j++)
scanf("%lf", &b[j][i]);
for (int i = 1; i <= n; i++)
{
int k = i;
for (int j = i + 1; j <= n; j++)
if (fabs(a[j][i]) > fabs(a[k][i])) k = j;
if (fabs(a[k][i]) < eps) return puts("NIE"), 0;
if (i != k) swap(a[i], a[k]), swap(b[i], b[k]);
db x = a[i][i];
for (int j = 1; j <= n; j++) a[i][j] /= x, b[i][j] /= x;
for (int j = 1; j <= n; j++)
{
if (i != j)
{
x = a[j][i];
for (int k = 1; k <= n; k++) a[j][k] -= a[i][k] * x, b[j][k] -= b[i][k] * x;
}
}
}
for (int i = 1; i <= n; i++)
{
memset(vis, false, (n + 1) * sizeof(bool));
if (!dfs1(i)) return puts("NIE"), 0;
}
puts("TAK");
for (int i = 1; i <= n; i++)
{
memset(vis, false, (n + 1) * sizeof(bool));
printf("%d\n", dfs2(i, i));
}
return 0;
}