追捕嘉然2.0
简要题意
个点, 条带权边的无向图,另外还有 条连接 和 的边 。
在保证每个点到 的最短距离不变的前提下,最多可以删掉 条边中的的几条。
解析
最暴力的想法即逐条断掉网道之后跑一遍最短路即可,但 也是 级别的,显然不可行。
显然,如果网道不在任何一条最短路上或者到与网道相连的点的最短路不止一条,那这条边就一定可以删。
在跑最短路时,我们开一个tot[v]
用于存储 号点在满足 为一条的最短路上点 的数量。其中 可以经过多条边甚至是零条,但 一定经过一条边。因此,我们其实是在统计形如 点对的边的数量。考虑到重边,形如 点对的边可能不止一对,所以当 为一条最短路时 有两条边且边权相等,必须统计两次。代码如下:
int v=edge[i].to;
if(ans[v]==ans[u]+edge[i].w) tot[v]++;
if(ans[v]>ans[u]+edge[i].w) {
ans[v]=ans[u]+edge[i].w;
tot[v]=1;
if(!vis[v])
q.push((priority) {ans[v],v});
}
在ans[v]==ans[u]+edge[i].w
时,即出现另外一条形如 点对的边,则tot[v]++
在进行松弛时,tot[v]
中记录的最短路已经不是最短路了,所以直接变成1(即当前这条),也有可能是该点之前压根没有经过,也变成1。
其实按照如下写法统计最短路的数量理论上可行。
int v=edge[i].to;
if(ans[v]==ans[u]+edge[i].w) tot[v]+=tot[u];
if(ans[v]>ans[u]+edge[i].w) {
ans[v]=ans[u]+edge[i].w;
tot[v]=tot[u];
if(!vis[v])
q.push((priority) {ans[v],v});
}
但是可能会有如下情况:
显然这样的路径条数是指数级的,所以不可取。
统计答案时,分两种情况:
-
到公交站的最短路比网道短,直接删了。
-
到公交站的最短路与网道相等,这说明这也是一条最短路。如果
tot[v]>1
,那说明不止一条有从1出发到的最短路,删就完了,这里一定要把tot[v]--
,因为有重边(感谢 指出)。
最后贴一下代码
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e6+12;
struct node {
int nxt,to,w;
} edge[maxn];
int head[maxn],cnt=0;
int n,m,k;
int a[maxn][3];
int tot[maxn];
void add(int u,int v,int w) {
cnt++;
edge[cnt].nxt=head[u];
edge[cnt].w=w;
edge[cnt].to=v;
head[u]=cnt;
}
int ans[maxn];
bool vis[maxn];
struct priority {
int ans,id;
bool operator <(priority x)const {
return x.ans<ans;
}
};
priority_queue<priority> q;
void dj() {
memset(ans,0x3f,sizeof(ans));
memset(vis,0,sizeof(vis));
ans[1]=0;
q.push((priority) {
0,1
});
// vis[1]=1;
while(!q.empty()) {
priority temp=q.top();
int u=temp.id;
q.pop();
if(vis[u]) continue;
vis[u]=1;
for(int i=head[u]; i; i=edge[i].nxt) {
int v=edge[i].to;
if(ans[v]==ans[u]+edge[i].w) tot[v]++;
if(ans[v]>ans[u]+edge[i].w) {
ans[v]=ans[u]+edge[i].w;
tot[v]=1;
if(!vis[v])
q.push((priority) {
ans[v],v
});
}
}
}
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cin>>n>>m>>k;
for(int i=1; i<=m; i++) {
int u,v,w;
cin>>u>>v>>w;
add(u,v,w);
add(v,u,w);
}
for(int i=1; i<=k; i++) {
int s,w;
cin>>s>>w;
a[i][1]=s,a[i][2]=w;
add(1,s,w);
add(s,1,w);
}
dj();
int res=0;
for(int i=1; i<=k; i++) {
int v=a[i][1],w=a[i][2];
if(ans[v]<w)
res++;
if(ans[v]==w) {
if(tot[v]>1) {
res++;
tot[v]--;
}
}
}
cout<<res<<endl;
return 0;
}