[BZOJ3270]博物馆(矩阵求逆)

题面

https://darkbzoj.tk/problem/3270

题解

前置知识:

本题的正常做法见https://blog.csdn.net/aarongzk/article/details/51473258。

这里提供一种不正常的矩阵求逆方法。之所以说它不正常,是因为我自己都不能证明这个算法的正确性,但是它是能过题的(

设状态\((t,u,v)\)表示当前时刻为t,两人分别处在位置u、v。设\(f[t][u][v]\)表示从初始状态出发,能够到达状态\((t,u,v)\)的概率。下记\(q[u]=\frac{1-p[u]}{deg[u]}\)\(link[u]\)表示u的相邻点集合,有转移方程如下:

\[f[t][u][v] = \begin{cases} f[t-1][u][v]\times p[u]p[v] \\ \sum\limits_{i \in link[u],i\neq v}f[t-1][i][v]\times q[i]p[v] \\ \sum\limits_{j \in link[v],j \neq u}f[t-1][u][j]\times p[u]q[j] \\ \sum\limits_{i \in link[u],j \in link[v],i \neq j}f[t-1][i][j]\times q[i]q[j] \end{cases} \]

发现这个转移可以用矩阵描述。

\[M \times \left[ \begin{matrix} f[t-1][1][1] \\ f[t-1][1][2] \\ \vdots \\f[t-1][n][n] \end{matrix} \right] = \left[ \begin{matrix} f[t][1][1] \\ f[t][1][2] \\ \vdots \\f[t][n][n] \end{matrix} \right] \]

其中转移矩阵M,是一个\(n^2 \times n^2\)的矩阵,可以通过上面的转移方程计算出来。

我们想计算出\(F[u][v]=\sum_{t=0}^{+ \infty}f[t][u][v]\),因为在u点相遇的答案就是\(\frac{F[u][u]}{\sum_i f[i][i]}\)。而

\[(E+M+M^2+…) \times \left[ \begin{matrix} f[0][1][1] \\ f[0][1][2] \\ \vdots \\f[0][n][n] \end{matrix} \right] = \left[ \begin{matrix} F[1][1] \\ F[1][2] \\ \vdots \\F[n][n] \end{matrix} \right] \]

显然\(f[0][u][v] = [u=a]\times[v=b]\);另外,\(E+M+M^2+…\)又可以化成\(\frac{E}{E-M}=(E-M)^{-1}\)

这样就做完了,总时间\(O(n^6)\)

不过,这个算法有两个BUG:

  1. 无法证明\(E-M\)的行列式不等于0,因此\((E-M)^{-1}\)不一定存在
  2. \(E+M+M^2…\)的收敛性无从证明,因此几何级数公式不一定适用

虽然这样做能过题,但是这两点仍未解决,欢迎在评论区发表见解~

代码

#include<bits/stdc++.h>

using namespace std;

#define ld long double
#define rg register
#define In inline

const int SN = 400;
const int N = 20;
const ld eps = 1e-9;

In int sgn(ld x){
	return x < eps ? -1 : x > eps;
}

In int read(){
	int s = 0,ww = 1;
	char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-')ww = -1;ch = getchar();}
	while('0' <= ch && ch <= '9'){s = 10 * s + ch - '0';ch = getchar();}
	return s * ww;
}

int n,m,sn;

In int id(int i,int j){
	return (i - 1) * n + j;
}

struct mat{
	ld d[SN+5][SN+5];
	friend mat operator + (mat a,mat b){
		for(rg int i = 1;i <= sn;i++)
			for(rg int j = 1;j <= sn;j++)a.d[i][j] += b.d[i][j];
		return a;
	}	
	friend mat operator - (mat a,mat b){	
		for(rg int i = 1;i <= sn;i++)
			for(rg int j = 1;j <= sn;j++)a.d[i][j] -= b.d[i][j];
		return a;
	}
	friend mat operator * (mat a,mat b){
		mat c;
		for(rg int i = 1;i <= sn;i++)
			for(rg int j = 1;j <= sn;j++)
				for(rg int k = 1;k <= sn;k++)c.d[i][j] += a.d[i][k] * b.d[k][j];
		return c;
	}
	void Swap(int i1,int i2){
		for(rg int j = 1;j <= sn;j++)swap(d[i1][j],d[i2][j]);
	}
	void mul(int i,ld k){
		for(rg int j = 1;j <= sn;j++)d[i][j] *= k;
	}
	void add(int i1,int i2,ld k){
		for(rg int j = 1;j <= sn;j++)d[i1][j] += d[i2][j] * k;
	}
}M,E;

mat inv(mat a){
	mat b = E;
	for(rg int j = 1;j <= sn;j++){
		for(rg int i = j;i <= sn;i++)if(sgn(a.d[i][j]) != 0){
			if(i != j)a.Swap(i,j),b.Swap(i,j);
			break;
		}
		ld x = 1.0 / a.d[j][j];
		a.mul(j,x),b.mul(j,x);
		for(rg int i = j + 1;i <= sn;i++){
			ld y = a.d[i][j];
			a.add(i,j,-y),b.add(i,j,-y);
		}
	}
	for(rg int j = sn;j >= 1;j--){
		for(rg int i = 1;i < j;i++){
			ld y = a.d[i][j];
			a.add(i,j,-y),b.add(i,j,-y);
		}
	}
	return b;	
}

int head[N+5],deg[N+5],cnt;

struct edge{
	int next,des;
}e[2*SN+5];

In void addedge(int a,int b){
	cnt++;
	deg[a]++;
	e[cnt].des = b;
	e[cnt].next = head[a];
	head[a] = cnt;
}

ld p[N+5],q[N+5];

int main(){
	int a,b;
	scanf("%d%d%d%d",&n,&m,&a,&b);
	sn = n * n;	
	for(rg int i = 1;i <= sn;i++)E.d[i][i] = 1;
	for(rg int i = 1;i <= m;i++){
		int u,v;
		scanf("%d%d",&u,&v);
		addedge(u,v);
		addedge(v,u);
	}
	for(rg int i = 1;i <= n;i++){
		double x;
		scanf("%lf",&x);
		p[i] = x;
		q[i] = (1.0 - p[i]) / deg[i];
	}
	for(rg int u = 1;u <= n;u++)
		for(rg int v = 1;v <= n;v++){
			if(u != v)M.d[id(u,v)][id(u,v)] = p[u] * p[v];
			for(rg int eu = head[u];eu;eu = e[eu].next){
				int i = e[eu].des;
				if(i != v)M.d[id(u,v)][id(i,v)] = q[i] * p[v];
			}
			for(rg int ev = head[v];ev;ev = e[ev].next){
				int j = e[ev].des;
				if(u != j)M.d[id(u,v)][id(u,j)] = p[u] * q[j];
			}
			for(rg int eu = head[u];eu;eu = e[eu].next){
				for(rg int ev = head[v];ev;ev = e[ev].next){
					int i = e[eu].des,j = e[ev].des;
					if(i != j)M.d[id(u,v)][id(i,j)] = q[i] * q[j];
				}
			}
		}
	M = inv(E - M);
	ld s = 0;
	for(rg int i = 1;i <= n;i++)s += M.d[id(i,i)][id(a,b)];
	for(rg int i = 1;i <= n;i++){
		double p = M.d[id(i,i)][id(a,b)] / s;
		printf("%.6lf",p);
		putchar(i == n ? '\n' : ' ');
	}
	return 0;
}
posted @ 2020-10-05 19:52  coder66  阅读(119)  评论(0编辑  收藏  举报