NOIP模拟 规避(最短路计数+乘法原理)
内网传送门
【题目分析】
第一次接触最短路计数。。。。。。qwq
首先,不相遇情况复杂,所以直接选择总方案-相遇方案。
显然,因为最短路长度一定,两人速度相同,那么他们如果相遇,那么一定在最短路的中点。
所以这个中点就有两个情况:在点上、在边上。
考虑在点上,我们跑正反两遍Dijkstra处理所有点到s和到t的最短距离,那么显然如果中点在点上,先判断该点是否在最短路径上,如果是就看是否有dis[i->s]=dis[i->t]。那么方案数就为((s->i的最短路数)*(t->i的最短路数))^2。
考虑在边上,首先设该边两点分别为u,v,如果符合,那么就有dis[u->s]+dis[v->t]+w[u->v]=最短路径长度,abs(dis[u->s]-dis[v->t])<w[u->v](否则中点不在该边上),合法方案数为((s->u的最短路数)*(t->v的最短路数))^2。
【代码~】
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const LL MAXN=1e5+10;
const LL MAXM=4e5+10;
const LL INF=0x3f3f3f3f;
const LL mod=1e9+7;
LL n,m,cnt,s,t;
LL ans;
LL head[MAXN],dis[MAXN],disS[MAXN],disT[MAXN],vis[MAXN],cntt1[MAXN],cntt2[MAXN];
LL nxt[MAXM],to[MAXM],w[MAXM],from[MAXM];
LL Read(){
LL i=0,f=1;
char c;
for(c=getchar();(c>'9'||c<'0')&&c!='-';c=getchar());
if(c=='-')
f=-1,c=getchar();
for(;c>='0'&&c<='9';c=getchar())
i=(i<<3)+(i<<1)+c-'0';
return i*f;
}
LL mul(LL x,LL y){
return x*y%mod;
}
void add(LL x,LL y,LL z){
nxt[cnt]=head[x];
head[x]=cnt;
from[cnt]=x;
to[cnt]=y;
w[cnt]=z;
cnt++;
}
void dijkstra(LL type){
priority_queue<pair<LL,LL> > q;
memset(vis,0,sizeof(vis));
memset(dis,INF,sizeof(dis));
q.push(make_pair(0,s));
dis[s]=0;
if(!type)
cntt1[s]=1;
else
cntt2[s]=1;
while(!q.empty()){
LL u=q.top().second;
q.pop();
if(vis[u])
continue;
vis[u]=1;
for(LL i=head[u];i!=-1;i=nxt[i]){
LL v=to[i];
if(dis[v]>dis[u]+w[i]){
dis[v]=dis[u]+w[i];
if(type)
cntt2[v]=cntt2[u];
else
cntt1[v]=cntt1[u];
q.push(make_pair(-dis[v],v));
}
else{
if(dis[v]==dis[u]+w[i]){
if(!type)
cntt1[v]=cntt1[u]+cntt1[v];
else
cntt2[v]=cntt2[u]+cntt2[v];
}
}
}
}
if(!type)
memcpy(disS,dis,sizeof(dis));
else
memcpy(disT,dis,sizeof(dis));
}
int main(){
memset(head,-1,sizeof(head));
n=Read(),m=Read();
s=Read(),t=Read();
for(LL i=1;i<=m;++i){
LL x=Read(),y=Read(),z=Read();
add(x,y,z),add(y,x,z);
}
dijkstra(0);
swap(s,t);
dijkstra(1);
LL diss=disT[t];
ans=mul(cntt1[s],cntt1[s]);
for(int i=1;i<=n;++i){
if(disS[i]+disT[i]!=diss)
continue;
if(disS[i]==disT[i])
ans=(ans-mul(mul(cntt1[i],cntt2[i]),mul(cntt1[i],cntt2[i]))+mod)%mod;
for(int j=head[i];j!=-1;j=nxt[j]){
int v=to[j];
if(disS[i]+disT[v]+w[j]!=diss)
continue;
if(abs(disS[i]-disT[v])<w[j])
ans=(ans-mul(mul(cntt1[i],cntt2[v]),mul(cntt1[i],cntt2[v]))+mod)%mod;
}
}
cout<<ans;
return 0;
}