线性代数+图论好题。

题目链接: (bzoj) https://www.lydsy.com/JudgeOnline/problem.php?id=3168

(luogu) https://www.luogu.org/problemnew/show/P4100

题解: 首先\(A\)矩阵必须满秩。有一个结论是,设矩阵\(C\)满足\(CA=B\), 则\(A\)的第\(i\)行可以被\(B\)的第\(j\)行来替代当且仅当\(C_{j,i}\ne 0\).

\(B_j\)可以用\(A\)除了\(i\)之外的行向量线性表示,那么\(B_j\)无法替换\(A_i\). 若\(C_{j,i}=0\)代表用\(A\)矩阵的行向量表示\(B_j\)的系数向量中\(A_i\)这一项的系数为\(0\).

那么\(CA=B\)可以推出\(CAA^{-1}=BA^{-1}, C=BA^{-1}\)

数学被各种吊打啊……

然后我们就在\(O(n^3)\)时间内求出了对于每一个\(i,j\), \(A\)中第\(i\)行是否能被\(B\)中第\(j\)行替换

问题转化成了,给一张二分图,保证有完美匹配,求一个完美匹配使得\(A\)中的每个点在\(B\)中的匹配点构成的排列的字典序最小。

这个东西,貌似必须用匈牙利算法。。先跑一边随便求出一个完美匹配,然后从\(1\)号到\(n\)号每个点再匹配尽量小的,如果能找到使当前\(i\)号点更小,且不影响\(i\)号点前面的点的交错路,那么就可以更新答案了。

时间复杂度\(O(n^3)\).

代码

#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cassert>
#include<iostream>
#define llong long long
using namespace std;

inline int read()
{
	int x=0; bool f=1; char c=getchar();
	for(;!isdigit(c);c=getchar()) if(c=='-') f=0;
	for(; isdigit(c);c=getchar()) x=(x<<3)+(x<<1)+(c^'0');
	if(f) return x;
	return -x;
}

const int N = 300;
const int P = 942030731;
llong quickpow(llong x,llong y)
{
	llong cur = x,ret = 1ll;
	for(int i=0; y; i++)
	{
		if(y&(1ll<<i)) {y-=(1ll<<i); ret = ret*cur%P;}
		cur = cur*cur%P;
	}
	return ret;
}
llong mulinv(llong x) {return quickpow(x,P-2);}
struct Matrix
{
	llong a[N+3][N+3]; int n;
	Matrix() {}
	Matrix(int _n) {n = _n; for(int i=1; i<=n; i++) for(int j=1; j<=n; j++) a[i][j] = 0ll;}
	void read(int _n)
	{
		n = _n;
		for(int i=1; i<=n; i++) for(int j=1; j<=n; j++) scanf("%lld",&a[i][j]);
	}
	void write()
	{
		printf("%d\n",n);
		for(int i=1; i<=n; i++) {for(int j=1; j<=n; j++) printf("%lld ",a[i][j]); puts("");}
	}
	Matrix operator *(const Matrix &arg)
	{
		Matrix ret = Matrix(n);
		for(int i=1; i<=n; i++)
		{
			for(int k=1; k<=n; k++)
			{
				for(int j=1; j<=n; j++)
				{
					ret.a[i][j] = (ret.a[i][j]+a[i][k]*arg.a[k][j])%P;
				}
			}
		}
		return ret;
	}
	Matrix inv()
	{
		Matrix ret = Matrix(n); for(int i=1; i<=n; i++) ret.a[i][i] = 1ll;
		for(int i=1; i<=n; i++)
		{
			if(a[i][i]==0)
			{
				bool found = false;
				for(int j=i+1; j<=n; j++)
				{
					if(a[j][i])
					{
						for(int k=1; k<=n; k++) {swap(a[i][k],a[j][k]),swap(ret.a[i][k],ret.a[j][k]);}
						found = true; break;
					}
				}
				if(found==false) {ret.a[0][0] = P; return ret;}
			}
			for(int j=i+1; j<=n; j++)
			{
				llong coe = (P-a[j][i]*mulinv(a[i][i])%P)%P;
				for(int k=1; k<=n; k++)
				{
					a[j][k] = (a[j][k]+coe*a[i][k])%P;
					ret.a[j][k] = (ret.a[j][k]+coe*ret.a[i][k])%P;
				}
			}
		}
		for(int i=1; i<=n; i++)
		{
			llong coe = mulinv(a[i][i]);
			for(int j=1; j<=n; j++) {a[i][j] = a[i][j]*coe%P; ret.a[i][j] = ret.a[i][j]*coe%P;}
		}
//		write();
//		ret.write();
		for(int i=n; i>=1; i--)
		{
			for(int j=n; j>=i+1; j--)
			{
				llong coe = (P-a[i][j]*mulinv(a[j][j])%P)%P;
				a[i][j] = 0ll;
				for(int k=1; k<=n; k++) ret.a[i][k] = (ret.a[i][k]+coe*ret.a[j][k])%P;
			}
		}
		return ret;
	}
} a,b,aux,c;
int g[N+3][N+3];
int vis[N+3];
int match1[N+3],match2[N+3];
int n;

bool dfs1(int u)
{
	for(int i=1; i<=n; i++)
	{
		if(g[i][u]==true && vis[i]==false)
		{
			vis[i] = true;
			if(match2[i]==0 || dfs1(match2[i])==true)
			{
				match2[i] = u; match1[u] = i;
				return true;
			}
		}
	}
	return false;
}

bool dfs2(int u,int u0)
{
	for(int i=1; i<=n; i++)
	{
		if(g[i][u]==true && vis[i]==false)
		{
			vis[i] = true;
			if(match2[i]==u0 || (match2[i]>u0 && dfs2(match2[i],u0)==true))
			{
				match2[i] = u; match1[u] = i;
				return true;
			}
		}
	}
	return false;
}

int main()
{
	scanf("%d",&n);
	a.read(n); b.read(n);
	aux = a.inv();
	if(aux.a[0][0]==P) {printf("NIE"); return 0;}
	c = b*aux;
	for(int i=1; i<=n; i++)
	{
		for(int j=1; j<=n; j++) g[i][j] = c.a[i][j]==0 ? 0 : 1;
	}
//	for(int i=1; i<=n; i++) {for(int j=1; j<=n; j++) printf("%d",g[i][j]); puts("");}
	int mf = 0;
	for(int i=1; i<=n; i++)
	{
		for(int j=1; j<=n; j++) vis[j] = false;
		mf += dfs1(i);
	}
	if(mf<n) {printf("NIE"); return 0;}
	printf("TAK\n");
	for(int i=1; i<=n; i++)
	{
		for(int j=1; j<=n; j++) vis[j] = false;
		dfs2(i,i);
	}
	for(int i=1; i<=n; i++) printf("%d\n",match1[i]);
	return 0;
}