【P1261】服务器储存信息问题【最短路】
题目大意:
题目链接:https://www.luogu.org/problemnew/show/P1261
Byteland王国准备在各服务器间建立大型网络并提供多种服务。
网络由n台服务器组成,用双向的线连接。两台服务器之间最多只能有一条线直接连接,同时,每台服务器最多只能和10台服务器直接连接,但是任意两台服务器间必然存在一条路径将它们连接在一起。每条传输线都有一个固定传输的速度。δ(V, W)表示服务器V和W之间的最短路径长度,且对任意的V有δ(V, V)=0。
有些服务器比别的服务器提供更多的服务,它们的重要程度要高一些。我们用r(V)表示服务器V的重要程度(rank)。rank越高的服务器越重要。
每台服务器都会存储它附近的服务器的信息。当然,不是所有服务器的信息都存,只有感兴趣的服务器信息才会被存储。服务器V对服务器W感兴趣是指,不存在服务器U满足,r(U)>r(W)且δ(V, U)<=δ(V, W)。
举个例子来说,所有具有最高rank的服务器都会被别的服务器感兴趣。如果V是一台具有最高rank的服务器,由于δ(V, V)=0,所以V只对具有最高rank的服务器感兴趣。我们定义B(V)为V感兴趣的服务器的集合。
我们希望计算所有服务器储存的信息量,即所有服务器的|B(V)|之和。Byteland王国并不希望存储大量的数据,所以所有服务器存储的数据量(|B(V)|之和)不会超过30n。
你的任务是写一个程序,读入Byteland王国的网络分布,计算所有服务器存储的数据量。
思路:
显然暴力做法就是跑遍单元最短路。
考虑到答案不会太大(所有服务器存储的数据量(之和)不会超过),所以可以考虑对暴力算法优化。
设表示距离点最近的的点的距离。
那么如果我们在跑点的最短路时,最短路跑到点,接下来要转移到点,但是有,那么就有
所以这样显然不会更新答案。
所以我们根本不必把扔进(优先)队列里。
这样就可以过了。
代码:
#include <queue>
#include <cstdio>
#include <cstring>
#define mp make_pair
using namespace std;
const int N=30010;
int n,m,tot,ans,rank[N],f[N][15],dis[N],head[N];
bool vis[N],flag[N];
queue<int> r[15];
struct edge
{
int next,to,dis;
}e[N*10];
void add(int from,int to,int dis)
{
e[++tot].to=to;
e[tot].dis=dis;
e[tot].next=head[from];
head[from]=tot;
}
void dij_rank(int rk)
{
memset(vis,0,sizeof(vis));
priority_queue<pair<int,int> > q;
while (r[rk].size())
{
int x=r[rk].front();
f[x][rk]=0;
q.push(mp(0,x));
r[rk].pop();
}
while (q.size())
{
int u=q.top().second,v;
q.pop();
if (vis[u]) continue;
vis[u]=1;
for (int i=head[u];~i;i=e[i].next)
{
v=e[i].to;
if (f[v][rk]>f[u][rk]+e[i].dis)
{
f[v][rk]=f[u][rk]+e[i].dis;
q.push(mp(-f[v][rk],v));
}
}
}
}
void dij(int S)
{
memset(vis,0,sizeof(vis));
memset(dis,0x3f3f3f3f,sizeof(dis));
memset(flag,0,sizeof(flag));
priority_queue<pair<int,int> > q;
q.push(mp(0,S));
dis[S]=0;
while (q.size())
{
int u=q.top().second,v;
if (!flag[u])
{
flag[u]=1;
ans++;
}
q.pop();
if (vis[u]) continue;
vis[u]=1;
for (int i=head[u];~i;i=e[i].next)
{
v=e[i].to;
if (dis[v]>dis[u]+e[i].dis)
{
dis[v]=dis[u]+e[i].dis;
if (dis[v]<f[v][rank[S]+1])
q.push(mp(-dis[v],v));
}
}
}
}
int main()
{
memset(head,-1,sizeof(head));
memset(f,0x3f3f3f3f,sizeof(f));
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++)
{
scanf("%d",&rank[i]);
r[rank[i]].push(i);
}
for (int i=1;i<=m;i++)
{
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
add(x,y,z);
add(y,x,z);
}
for (int i=10;i>=1;i--)
{
dij_rank(i);
if (i<10)
for (int j=1;j<=n;j++)
if (f[j][i]>f[j][i+1]) f[j][i]=f[j][i+1];
}
for (int i=1;i<=n;++i)
dij(i);
printf("%d",ans);
return 0;
}