zjoj1706: [usaco2007 Nov]relays 奶牛接力跑
矩阵乘法(快速幂)
为说明方便,这里让\(k\)为点数,\(n\)为路径长度。
先将点都离散化,这样最后的点只有\(2k\)个。
先考虑一种暴力,每次用\(O(k^3)\)的复杂度来暴力更新,设当前长度\(l\)点的两两最短路矩阵为\(S\),现在要转移到\(l+1\)时的最短路矩阵\(T\)。我们考虑用每条边更新,对于某条从\(x\)连向\(y\)的长度为\(z\)的边,对于任一点\(i\),有:
\[T[i][y]=min(T[i][y],T[i][x]+z)
\]
另外,每次更新时,\(T\)矩阵的初始值为无限大。
然后我们就可以用\(O(nk^3)\)的复杂度去做这道题了。但这明显不行。
我们设没有直接连通的两个点距离为无限大,构建出邻接矩阵\(D\),就可以魔改一下上面的式子,改成:
\[T[i][j]=min(T[i][x]+D[x][j])
\]
其中\(x\)为自己枚举的中间节点,然后就出现的如下的代码:
for(int i=0;i<k;++i){
for(int j=0;j<k;++j){
for(int l=0;l<k;++l){
ret.a[i][j]=min(ret.a[i][j],a.a[i][l]+b.a[l][j]);
}
}
}
发现,这不是就是矩阵乘法吗?
因为取最小值满足可加性,所以使用矩阵快速幂是可行的。这样,我们就能把复杂度优化为\(O(lognk^3)\)
然后,我就不开O2过不了了。
我们发现从源点能到达的点数最多只有\(k+1\)(因为即使走过每条边都发现一个新节点,也只能发现这么多点。)所以我们可以只用源点能到的点进行离散化,可以将点数从\(2k\)到\(k\),从而在矩阵乘法时省掉8倍常数,然后就可以不开O2AC了。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=1000010,M=200;
int n,k,s,t;
struct data{
ll a[M][M];
data(){memset(a,0,sizeof a);}
}a;
data operator*(const data&a,const data&b){
data ret;
memset(ret.a,0x3f,sizeof ret.a);
for(int i=0;i<k;++i){
for(int j=0;j<k;++j){
for(int l=0;l<k;++l){
ret.a[i][j]=min(ret.a[i][j],a.a[i][l]+b.a[l][j]);
}
}
}
return ret;
}
data mpow(data a,int n){
data ret=a;
n--;
while(n){
if(n&1)ret=ret*a;
n/=2;
a=a*a;
}
return ret;
}
int tot,bian[N],nxt[N],head[N];
void add(int x,int y){
tot++,bian[tot]=y,nxt[tot]=head[x],head[x]=tot;
}
struct edge{
int x,y;
ll z;
}e[M];
int vis[N];
vector<int>v;
void dfs(int x){
if(vis[x])return;
v.push_back(x);
vis[x]=1;
for(int i=head[x];i;i=nxt[i]){
int y=bian[i];
dfs(y);
}
}
int main(){
cin>>n>>k>>s>>t;
memset(a.a,0x3f,sizeof a.a);
for(int i=1;i<=k;++i){
scanf("%lld%d%d",&e[i].z,&e[i].x,&e[i].y);
add(e[i].x,e[i].y);
add(e[i].y,e[i].x);
}
dfs(s);
sort(v.begin(),v.end());
for(int i=1;i<=k;++i){
if(!vis[e[i].x])continue;
int x=lower_bound(v.begin(),v.end(),e[i].x)-v.begin(),
y=lower_bound(v.begin(),v.end(),e[i].y)-v.begin();
a.a[y][x]=a.a[x][y]=min(a.a[x][y],e[i].z);
}
data ret=mpow(a,n);
s=lower_bound(v.begin(),v.end(),s)-v.begin();
t=lower_bound(v.begin(),v.end(),t)-v.begin();
cout<<ret.a[s][t]<<endl;
}