xay loves Floyd

题意:\(floyd\)被写成了这个样子:

for i from 1 to n
  for j from 1 to n
    for k from 1 to n
      dis[i][j] <- min(dis[i][j], dis[i][k] + dis[k][j])

求最后有多少个位置仍然是相等的。

解法1:这个错误的东西其实就是枚举每个点然后去更新其他的点,那么我们模拟这个过程去做,对于一个点对\((i,j)\),我们需要知道是否存在一个中转点\(k\)使得\((i,k)\)是最短路边,\((k,j)\)是最短路边,并且\((i,k)+(k,j)\)为最短路,所以做法就是对于每个根建出最短路的\(DAG\),拓扑对每个点求出每个点会有多少点能够到达它,判断\((i,j)\)是不是最短路时,用三个\(bitset\)与一下就行了。

解法2:直接去跑错误的做法,先把原图上的边权=最短路的边权加入到边集中,然后再对每个起点跑最短路,对于起点\(s\),可以发现它只有第一跳是可以向前跳的之外其他的都只能向后跳,在做\(spfa\)的时候记录一下就可以了,在做的过程中发现有当前最短路已经完成的时候就把边加进去。

#include <bits/stdc++.h>
#define N 2009
#define mm make_pair
using namespace std;
typedef long long ll;
const ll mod=99824353;
int tot,n,m,head[N],dis[N][N],p[N][N];
bool vis[N];
inline ll rd(){
    ll x=0;char c=getchar();bool f=0;
    while(!isdigit(c)){if(c=='-')f=1;c=getchar();}
    while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
    return f?-x:x;
}
struct edge{
	int n,to,l;
}e[1000009];
inline void add(int u,int v,int l){
	e[++tot].n=head[u];
	e[tot].to=v;
	head[u]=tot;
	e[tot].l=l;
}
void spfa1(int id){
	queue<int>q;
	memset(dis[id],0x3f,sizeof(dis[id]));
	dis[id][id]=0;
	q.push(id);
	while(!q.empty()){
		int u=q.front();q.pop();vis[u]=0;
		for(int i=head[u];i;i=e[i].n){
			int v=e[i].to;
			if(dis[id][v]>dis[id][u]+e[i].l){
				dis[id][v]=dis[id][u]+e[i].l;
				if(!vis[v]){
					vis[v]=1;
					q.push(v);
				}
			}
	    }
	}
}
void spfa2(int id){
	queue<pair<int,int> >q;
	for(int i=head[id];i;i=e[i].n)q.push(mm(0,e[i].to));
	while(!q.empty()){
		pair<int,int> nw=q.front();q.pop();
		int u=nw.second;vis[u]=0;
		for(int i=head[u];i;i=e[i].n){
			int v=e[i].to;
			if(v<nw.first)continue;
			if(dis[id][v]==p[id][v]||dis[id][v]>1e9)continue;
			if(dis[id][v]==dis[id][u]+e[i].l){
				p[id][v]=dis[id][v];
			    add(id,v,dis[id][v]);
				if(!vis[v]){
					vis[v]=1;
					q.push(mm(v,v));
				}
			}
		}
    }
}
int main(){
   n=rd();m=rd();
   memset(p,0x3f,sizeof(p));
   int u,v,w;
   for(int i=1;i<=m;++i){
   	   u=rd();v=rd();w=rd();
   	   p[u][v]=w;
   	   add(u,v,w);
   }
   for(int i=1;i<=n;++i)p[i][i]=0;
   for(int i=1;i<=n;++i)spfa1(i);
   memset(head,0,sizeof(head));tot=0;
   for(int i=1;i<=n;++i)
       for(int j=1;j<=n;++j)if(i!=j&&p[i][j]<1e9&&p[i][j]==dis[i][j]){
       	   add(i,j,dis[i][j]);
       }
   for(int i=1;i<=n;++i)spfa2(i);
   int ans=0;
   for(int i=1;i<=n;++i)
       for(int j=1;j<=n;++j)if(dis[i][j]==p[i][j])ans++;
    printf("%d\n",ans);
   return 0;
}
posted @ 2021-08-08 20:46  comld  阅读(118)  评论(0编辑  收藏  举报