zjoj1706: [usaco2007 Nov]relays 奶牛接力跑

矩阵乘法(快速幂)

为说明方便,这里让\(k\)为点数,\(n\)为路径长度。

先将点都离散化,这样最后的点只有\(2k\)个。

先考虑一种暴力,每次用\(O(k^3)\)的复杂度来暴力更新,设当前长度\(l\)点的两两最短路矩阵为\(S\),现在要转移到\(l+1\)时的最短路矩阵\(T\)。我们考虑用每条边更新,对于某条从\(x\)连向\(y\)的长度为\(z\)的边,对于任一点\(i\),有:

\[T[i][y]=min(T[i][y],T[i][x]+z) \]

另外,每次更新时,\(T\)矩阵的初始值为无限大。

然后我们就可以用\(O(nk^3)\)的复杂度去做这道题了。但这明显不行。

我们设没有直接连通的两个点距离为无限大,构建出邻接矩阵\(D\),就可以魔改一下上面的式子,改成:

\[T[i][j]=min(T[i][x]+D[x][j]) \]

其中\(x\)为自己枚举的中间节点,然后就出现的如下的代码:

for(int i=0;i<k;++i){
    for(int j=0;j<k;++j){
        for(int l=0;l<k;++l){
            ret.a[i][j]=min(ret.a[i][j],a.a[i][l]+b.a[l][j]);
		}
	}
}

发现,这不是就是矩阵乘法吗?

因为取最小值满足可加性,所以使用矩阵快速幂是可行的。这样,我们就能把复杂度优化为\(O(lognk^3)\)

然后,我就不开O2过不了了。

我们发现从源点能到达的点数最多只有\(k+1\)(因为即使走过每条边都发现一个新节点,也只能发现这么多点。)所以我们可以只用源点能到的点进行离散化,可以将点数从\(2k\)\(k\),从而在矩阵乘法时省掉8倍常数,然后就可以不开O2AC了。

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=1000010,M=200;
int n,k,s,t;
struct data{
    ll a[M][M];
    data(){memset(a,0,sizeof a);}
}a;
data operator*(const data&a,const data&b){
    data ret;
    memset(ret.a,0x3f,sizeof ret.a);
    for(int i=0;i<k;++i){
        for(int j=0;j<k;++j){
            for(int l=0;l<k;++l){
                ret.a[i][j]=min(ret.a[i][j],a.a[i][l]+b.a[l][j]);
			}
		}
	}
	return ret;
}
data mpow(data a,int n){
    data ret=a;
    n--;
    while(n){
         if(n&1)ret=ret*a;
         n/=2;
         a=a*a;
	}
	return ret;
}
int tot,bian[N],nxt[N],head[N];
void add(int x,int y){
    tot++,bian[tot]=y,nxt[tot]=head[x],head[x]=tot;
}
struct edge{
    int x,y;
    ll z;
}e[M];
int vis[N];
vector<int>v;
void dfs(int x){
    if(vis[x])return;
    v.push_back(x);
    vis[x]=1;
	for(int i=head[x];i;i=nxt[i]){
	    int y=bian[i];
	    dfs(y);
	}
}
int main(){
    cin>>n>>k>>s>>t;
    memset(a.a,0x3f,sizeof a.a);
    for(int i=1;i<=k;++i){
        scanf("%lld%d%d",&e[i].z,&e[i].x,&e[i].y);
        add(e[i].x,e[i].y);
        add(e[i].y,e[i].x);
	}
	dfs(s);
	sort(v.begin(),v.end());
	for(int i=1;i<=k;++i){
		if(!vis[e[i].x])continue;
		int x=lower_bound(v.begin(),v.end(),e[i].x)-v.begin(),
		    y=lower_bound(v.begin(),v.end(),e[i].y)-v.begin();
	    a.a[y][x]=a.a[x][y]=min(a.a[x][y],e[i].z);
	    
	}
    data ret=mpow(a,n);
    s=lower_bound(v.begin(),v.end(),s)-v.begin();
    t=lower_bound(v.begin(),v.end(),t)-v.begin();
    cout<<ret.a[s][t]<<endl;
}
posted @ 2019-07-09 10:40  整理者  阅读(243)  评论(0编辑  收藏  举报