NOIP模拟 规避(最短路计数+乘法原理)

内网传送门

【题目分析】

第一次接触最短路计数。。。。。。qwq

首先,不相遇情况复杂,所以直接选择总方案-相遇方案。

显然,因为最短路长度一定,两人速度相同,那么他们如果相遇,那么一定在最短路的中点。

所以这个中点就有两个情况:在点上、在边上。

考虑在点上,我们跑正反两遍Dijkstra处理所有点到s和到t的最短距离,那么显然如果中点在点上,先判断该点是否在最短路径上,如果是就看是否有dis[i->s]=dis[i->t]。那么方案数就为((s->i的最短路数)*(t->i的最短路数))^2。

考虑在边上,首先设该边两点分别为u,v,如果符合,那么就有dis[u->s]+dis[v->t]+w[u->v]=最短路径长度,abs(dis[u->s]-dis[v->t])<w[u->v](否则中点不在该边上),合法方案数为((s->u的最短路数)*(t->v的最短路数))^2。

【代码~】

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const LL MAXN=1e5+10;
const LL MAXM=4e5+10;
const LL INF=0x3f3f3f3f;
const LL mod=1e9+7;

LL n,m,cnt,s,t;
LL ans;
LL head[MAXN],dis[MAXN],disS[MAXN],disT[MAXN],vis[MAXN],cntt1[MAXN],cntt2[MAXN];
LL nxt[MAXM],to[MAXM],w[MAXM],from[MAXM];

LL Read(){
	LL i=0,f=1;
	char c;
	for(c=getchar();(c>'9'||c<'0')&&c!='-';c=getchar());
	if(c=='-')
	  f=-1,c=getchar();
	for(;c>='0'&&c<='9';c=getchar())
	  i=(i<<3)+(i<<1)+c-'0';
	return i*f;
}

LL mul(LL x,LL y){
	return x*y%mod;
}

void add(LL x,LL y,LL z){
	nxt[cnt]=head[x];
	head[x]=cnt;
	from[cnt]=x;
	to[cnt]=y;
	w[cnt]=z;
	cnt++;
}

void dijkstra(LL type){
	priority_queue<pair<LL,LL> > q;
	memset(vis,0,sizeof(vis));
	memset(dis,INF,sizeof(dis));
	q.push(make_pair(0,s));
	dis[s]=0;
	if(!type)
	  cntt1[s]=1;
	else
	  cntt2[s]=1;
	while(!q.empty()){
		LL u=q.top().second;
		q.pop();
		if(vis[u])
		  continue;
		vis[u]=1;
		for(LL i=head[u];i!=-1;i=nxt[i]){
			LL v=to[i];
			if(dis[v]>dis[u]+w[i]){
				dis[v]=dis[u]+w[i];
				if(type)
				  cntt2[v]=cntt2[u];
				else
				  cntt1[v]=cntt1[u];
				q.push(make_pair(-dis[v],v));
			}
			else{
				if(dis[v]==dis[u]+w[i]){
				    if(!type)
					  cntt1[v]=cntt1[u]+cntt1[v];
					else
					  cntt2[v]=cntt2[u]+cntt2[v];
				}
			}
		}
	}
	if(!type)
	  memcpy(disS,dis,sizeof(dis));
	else
	  memcpy(disT,dis,sizeof(dis));
}

int main(){
	memset(head,-1,sizeof(head));
	n=Read(),m=Read();
	s=Read(),t=Read();
	for(LL i=1;i<=m;++i){
		LL x=Read(),y=Read(),z=Read();
		add(x,y,z),add(y,x,z);
	}
	dijkstra(0);
	swap(s,t);
	dijkstra(1);
	LL diss=disT[t];
	ans=mul(cntt1[s],cntt1[s]);
	for(int i=1;i<=n;++i){
		if(disS[i]+disT[i]!=diss)
		  continue;
		if(disS[i]==disT[i])
		  ans=(ans-mul(mul(cntt1[i],cntt2[i]),mul(cntt1[i],cntt2[i]))+mod)%mod;
		for(int j=head[i];j!=-1;j=nxt[j]){
			int v=to[j];
			if(disS[i]+disT[v]+w[j]!=diss)
			  continue;
			if(abs(disS[i]-disT[v])<w[j])
			  ans=(ans-mul(mul(cntt1[i],cntt2[v]),mul(cntt1[i],cntt2[v]))+mod)%mod;
		}
	}
	cout<<ans;
	return 0;
}

 

posted @ 2018-11-05 18:22  Ishtar~  阅读(142)  评论(0编辑  收藏  举报