Counting Shortcuts(CF 1650 G)
题目大意
有\(T\)组样例,每组样例有\(n\)个点,\(m\)条边(无重边,无自环),起点为\(s\),终点为\(t\),问起点到终点的路径(可以不是简单路径)中,最短路的条数加上次短路的条数,答案对\(1e9+7\)取模。\((1\leq T\leq 10^4,2\leq n\leq 2*10^5,1\leq m\leq 2*10^5,s\not=t)\)
思路
虽然题目说可以不是简单路径,但是很容易能发现,真正能计数的还是简单路径。所以我们就可以先预处理出第\(i\)个点到起点的最短距离\(dep[i]\),然后根据\(dep[i]\)从小到大对\(i\)进行排序,不知道叫啥名儿,就先暂且叫它\(bfs\)序吧。然后我们可以再来一遍\(bfs\),这时候考虑\(dp\)转移。我们定义\(dp1[i]\)为\(s\)到\(i\)的路径中最短路的条数,\(dp2[i]\)为\(s\)到\(i\)的路径中次短路的条数,然后我们根据刚才处理出来的\(bfs\)序,按照\(bfs\)序的顺序,对于点\(u\),遍历与他相邻的点\(v\),如果\(dep[u]==dep[v]\),那就可以\(dp2[v]+=dp1[u]\),如果\(dep[u]+1==dep[v]\),就可以\(dp1[v]+=dp1[u],dp2[v]+=dp2[u]\)。但是这样还有一个问题,就是\(dp1[u]\)还没完全更新完你就在拿它更新\(dp2[v]\)了,于是我们可以想到,对于\(dep\)相同的点,我们都\(for\)两遍,第一遍遍历与之\(dep\)相同的点并更新,第二遍再更新第二种情况,这样就可以完美解决问题了。
代码
#include<bits/stdc++.h>
using namespace std;
long long mod=1e9+7;
int n,m;
int s,t;
const int maxn=200005;
struct EDGE
{
int next,to;
}edge[maxn<<1];
int head[maxn];
int cnt;
void add(int u,int v)
{
edge[cnt].to=v;
edge[cnt].next=head[u];
head[u]=cnt++;
}
struct node
{
int id,dep;
};
int dep[maxn];
bool vis[maxn];
vector<int>dots;
queue<node>q;
long long dp1[maxn],dp2[maxn];
int main()
{
memset(head,-1,sizeof(head));
int _;
scanf("%d",&_);
while(_--)
{
scanf("%d%d",&n,&m);
scanf("%d%d",&s,&t);
for(int i=1;i<=m;i++)
{
int u,v;
scanf("%d%d",&u,&v);
add(u,v);
add(v,u);
}
q.push({s,0});
dots.push_back(s);
vis[s]=1;
function<void()>bfs1=[&]()
{
while(!q.empty())
{
auto [x,Dep]=q.front();
q.pop();
for(int i=head[x];~i;i=edge[i].next)
{
int it=edge[i].to;
if(!vis[it])
{
dep[it]=Dep+1;
vis[it]=1;
dots.push_back(it);
q.push({it,Dep+1});
}
}
}
};
bfs1();
dp1[s]=1;
int siz=dots.size();
for(int i=0;i<siz;)
{
int j=i;
while(j<siz&&dep[dots[i]]==dep[dots[j]])j++;
for(int k=i;k<j;k++)
{
for(int l=head[dots[k]];~l;l=edge[l].next)
{
int it=edge[l].to;
if(dep[dots[k]]==dep[it])dp2[it]=(dp2[it]+dp1[dots[k]])%mod;
}
}
for(int k=i;k<j;k++)
{
for(int l=head[dots[k]];~l;l=edge[l].next)
{
int it=edge[l].to;
if(dep[dots[k]]+1==dep[it])
{
dp1[it]=(dp1[it]+dp1[dots[k]])%mod;
dp2[it]=(dp2[it]+dp2[dots[k]])%mod;
}
}
}
i=j;
}
printf("%lld\n",(dp1[t]+dp2[t])%mod);
if(_)
{
cnt=0;
for(int i=1;i<=n;i++)
{
dep[i]=0;
vis[i]=0;
head[i]=-1;
dp1[i]=dp2[i]=0;
}
for(int i=0;i<cnt;i++)
{
edge[i].to=edge[i].next=0;
}
dots.clear();
while(!q.empty())q.pop();
}
}
return 0;
}