题解 欢乐豆
首先有个显然得过分但我没有看出来的结论:若没有修改,从 \(u\) 到 \(v\) 的最短路长度就是 \(a_u\)(从 \(u\) 出发的每条边边权都是 \(a_u\),直接走 \(u\to v\) 即可)
那现在改了几条边,最短路就可能变成 \(\min_z\{dis_{u, z}+a_z\}\),即先走到某个点 \(z\),再经由一条未修改的边 \(z\to v\)(边权 \(a_z\))一步到达
发现这个 \(z\) 和 \(u\) 之间必须通过修改过的边连通,否则绕一下一定不优
于是将每个修改过的边连出的连通块单独拿出来考虑,跑个全源最短路
然后从连通块内的某个点 \(x\) 到连通块外的某个点 \(y\) 的最短路就是 \(\min\{a_x, \min_z(dis_{x, z}+a_z)\}\),其中 \(z\) 取遍连通块内的点(取 \(z=x\) 时即为 \(a_x\))
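特别注意一个细节:从连通块内的点 \(x\) 到同一连通块内另一个点 \(y\) 的最短路,可能需要先绕到连通块外的某个点 \(w\) 再回来,代价为 \(\min_z(dis_{x, z}+a_z)+\min_{w\notin S}a_w\)(先以最小代价离开连通块,到达 \(a\) 最小的块外点,再一步回到 \(y\));注意不能只用 \(a_x+\min_{w\notin S}a_w\),因为先在块内走几步再离开可能更便宜
但这部分没必要算到 \(dis_{x, y}\) 里,在统计答案时与块内最短路取 min 即可,二者等价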
然后考虑优化全源最短路的过程
发现从一个点 \(t\) 出发的边中,除了被修改的那几条,边权都等于 \(a_t\),于是可以用线段树区间取 min 来批量完成松弛过程
注意需要在线段树上打个删除标记以确定一个点不会被取出两次
复杂度 \(O(m^2\log m)\)
Code:
#include <bits/stdc++.h>
using namespace std;
// Sentinel "infinity" for 64-bit distances: seven 0x3f bytes, about 1.8e16.
#define INF 0x3f3f3f3f3f3f3f
// Capacity of the vertex/edge-indexed arrays (valid indices 0..N-1); presumably
// derived from the problem's limits on n and m -- TODO confirm.
#define N 100010
// Contest shorthands. These are textual macros: every later token `ll`, `fir`,
// `sec`, `make` and `pb` in this file is rewritten by the preprocessor.
#define ll long long
#define fir first
#define sec second
#define make make_pair
#define pb push_back
// Disabled switch that would widen every `int` to 64 bits; main is declared
// `signed` so that it would keep compiling if this were enabled.
//#define int long long
// Buffered input: stdin is pulled in chunks of up to 1<<21 bytes into buf, and
// [p1, p2) is the unread part. The getchar() macro below replaces the stdio
// function for every later use in this file and yields EOF once fread returns 0.
// Bytes are returned as unsigned char so a 0xFF byte is never mistaken for EOF.
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:(unsigned char)*p1++)
// Reads one decimal integer from the buffered input.
// Non-digit characters before the number are skipped; each '-' among them flips
// the sign. Returns 0 if input ends before any digit appears. (Previously the
// skip loop never terminated at EOF, and isdigit() was given possibly-negative
// char values, which is undefined behaviour.)
inline int read() {
    int ans=0, f=1;
    int c=getchar();
    while (c!=EOF && !(c>='0' && c<='9')) {
        if (c=='-') f=-f;
        c=getchar();
    }
    while (c>='0' && c<='9') {
        ans=ans*10+(c-'0');
        c=getchar();
    }
    return ans*f;
}
// Input shared by both solvers: n vertices, m modified edges, and a[v] (1-indexed),
// the default weight of every edge leaving v.
int n;
int m;
int a[N];
// Brute-force reference solver: dense Floyd-Warshall over the whole graph.
// dis is 510x510, so it is only usable for n <= 509.
namespace force{
int dis[510][510];
ll ans;
// Reads the m modified edges (u, v, w), runs Floyd-Warshall, prints the sum of
// dist(i, j) over all ordered pairs i != j, and terminates the program.
void solve() {
    // Default weights: every edge out of `from` costs a[from]. The diagonal is not
    // written here (static storage, zero-initialised).
    for (int from=1; from<=n; ++from) {
        for (int to=1; to<=n; ++to) {
            if (from==to) continue;
            dis[from][to]=a[from];
        }
    }
    // Overwrite the modified edges; values are read in u, v, w order.
    for (int e=0; e<m; ++e) {
        int u=read();
        int v=read();
        dis[u][v]=read();
    }
    for (int mid=1; mid<=n; ++mid) {
        for (int from=1; from<=n; ++from) {
            for (int to=1; to<=n; ++to) {
                int via=dis[from][mid]+dis[mid][to];
                if (via<dis[from][to]) dis[from][to]=via;
            }
        }
    }
    for (int from=1; from<=n; ++from) {
        for (int to=1; to<=n; ++to) {
            if (from!=to) ans+=dis[from][to];
        }
    }
    printf("%lld\n", ans);
    exit(0);
}
}
// Fast solver.
// Every modified edge (u, v, w) puts u and v into the same DSU component, so a
// modified edge never connects two different components. A vertex touched by no
// modification reaches every other vertex at cost exactly a[v]. For each component,
// shortest paths that stay inside it are computed by one Dijkstra per source whose
// relaxations through unmodified edges are batched as range-chmin updates on a
// segment tree; paths that leave the component are accounted for in calc().
namespace task{
ll ans;
int dsu[N], rk[N], tot;
// dis[i][j]: shortest path between local indices i and j of the current component,
// using only vertices of that component. Assumes every component has at most 3010
// vertices (a component built from k modified edges has at most k+1) -- TODO
// confirm against the problem limits on m.
ll dis[3010][3010];
// vis[v]: v is an endpoint of some modified edge; vis2[root]: component already handled.
bool vis[N], vis2[N];
// bge: (a[v], v) sorted ascending; tem: scratch list of (local target, weight).
pair<int, int> bge[N], tem[N];
// s[root]: vertices of the component with DSU root `root`, in increasing id order;
// rk[v] is v's position in that list.
vector<int> s[N];
// A modified edge u -> v with weight w; build() reads it from the input.
struct cge{int u, v, w; inline void build(){u=read(); v=read(); w=read();}}cg[N];
// rel[u]: modified edges leaving u.
vector<cge> rel[N];
inline int find(int p) {return dsu[p]==p?p:dsu[p]=find(dsu[p]);}

// Segment tree over local indices [0, siz), used as the Dijkstra priority queue:
//   minn(p) - minimum tentative distance among not-yet-extracted leaves of p
//   mini(p) - a leaf index attaining minn(p)
//   tag(p)  - pending chmin to push to the children (INF = none)
//   full(p) - every leaf of p has been extracted; updates skip such nodes
int tl[N<<2], tr[N<<2], mini[N<<2];
ll tag[N<<2], minn[N<<2];
bool full[N<<2];
#define tl(p) tl[p]
#define tr(p) tr[p]
#define tag(p) tag[p]
#define minn(p) minn[p]
#define mini(p) mini[p]
#define full(p) full[p]
inline void pushup(int p) {
    full(p)=full(p<<1)&full(p<<1|1);
    minn(p)=min(minn(p<<1), minn(p<<1|1));
    if (minn(p<<1)==minn(p)) mini(p)=mini(p<<1);
    else mini(p)=mini(p<<1|1);
}
// Push the pending chmin of p into its children that still have live leaves.
inline void spread(int p) {
    if (tag(p)==INF) return ;
    if (!full(p<<1)) {minn(p<<1)=min(minn(p<<1), tag(p)); tag(p<<1)=min(tag(p<<1), tag(p));}
    if (!full(p<<1|1)) {minn(p<<1|1)=min(minn(p<<1|1), tag(p)); tag(p<<1|1)=min(tag(p<<1|1), tag(p));}
    tag(p)=INF;
}
// Reset node p to cover [l, r] with every leaf holding val and nothing extracted.
void build(int p, int l, int r, ll val) {
    tl(p)=l; tr(p)=r; tag(p)=INF; full(p)=0;
    if (l==r) {minn(p)=val; mini(p)=l; return ;}
    int mid=(l+r)>>1;
    build(p<<1, l, mid, val);
    build(p<<1|1, mid+1, r, val);
    pushup(p);
}
// Point chmin: leaf pos becomes min(current, val) unless already extracted.
void upd(int p, int pos, ll val) {
    if (tl(p)==tr(p)) {if (!full(p)) minn(p)=min(minn(p), val); return ;}
    spread(p);
    int mid=(tl(p)+tr(p))>>1;
    if (pos<=mid) upd(p<<1, pos, val);
    else upd(p<<1|1, pos, val);
    pushup(p);
}
// Extract leaf pos: mark it done and hide it from minimum queries (value INF).
void del(int p, int pos) {
    if (tl(p)==tr(p)) {minn(p)=INF; full(p)=1; return ;}
    spread(p);
    int mid=(tl(p)+tr(p))>>1;
    if (pos<=mid) del(p<<1, pos);
    else del(p<<1|1, pos);
    pushup(p);
}
// Range chmin on every not-yet-extracted leaf in [l, r] (lazy). Because val reaches
// every live leaf under a covered node, the old mini(p) leaf still attains the new minimum.
void upd(int p, int l, int r, ll val) {
    if (l<=tl(p)&&r>=tr(p)) {if (!full(p)) {minn(p)=min(minn(p), val); tag(p)=min(tag(p), val);} return ;}
    spread(p);
    int mid=(tl(p)+tr(p))>>1;
    if (l<=mid) upd(p<<1, l, r, val);
    if (r>mid) upd(p<<1|1, l, r, val);
    pushup(p);
}
// Debug helper (not called): copy every leaf's current value into out[leaf index].
// Takes ll* to match the 64-bit values stored in minn (previously int*).
void query(int p, ll* out) {
    if (tl(p)==tr(p)) {out[tl(p)]=minn(p); return ;}
    spread(p);
    query(p<<1, out); query(p<<1|1, out);
}
// O(siz^3) alternative to dijkstra(): fills dis for component id by Floyd-Warshall.
void floyd(int id) {
    int siz=s[id].size();
    for (int i=0; i<siz; ++i) for (int j=0; j<siz; ++j) dis[i][j]=0;
    for (int i=0; i<siz; ++i) for (int j=0; j<siz; ++j) if (i!=j) dis[i][j]=a[s[id][i]];
    for (int i=0; i<siz; ++i) for (auto it:rel[s[id][i]]) dis[i][rk[it.v]]=it.w;
    for (int k=0; k<siz; ++k) for (int i=0; i<siz; ++i) for (int j=0; j<siz; ++j) dis[i][j]=min(dis[i][j], dis[i][k]+dis[k][j]);
}
// All-pairs shortest paths inside component id (local indices 0..siz-1) into dis,
// one Dijkstra per source i. When vertex t is extracted, its unmodified out-edges all
// cost a[s[id][t]], so the gaps between its modified targets (sorted by local index)
// are relaxed with range updates and each modified target with a point update.
void dijkstra(int id) {
    int siz=s[id].size();
    for (int i=0; i<siz; ++i) {
        dis[i][i]=0;
        // Live leaves start at INF-1, strictly below extracted leaves (INF), so
        // mini() always points at a live leaf while any remain.
        build(1, 0, siz-1, INF-1);
        upd(1, i, 0);
        for (int j=1; j<=siz; ++j) {
            int t=mini[1];
            dis[i][t]=minn[1];
            del(1, t);
            const vector<cge>& out=rel[s[id][t]];
            int deg=out.size();
            ll viaDefault=dis[i][t]+a[s[id][t]];
            if (!deg) {upd(1, 0, siz-1, viaDefault); continue;}
            for (int k=0; k<deg; ++k) tem[k+1]=make(rk[out[k].v], out[k].w);
            // Sentinel at index siz closes the last gap.
            tem[deg+1]=make(siz, 0);
            sort(tem+1, tem+deg+2);
            if (tem[1].fir>0) upd(1, 0, tem[1].fir-1, viaDefault);
            for (int k=1; k<=deg; ++k) {
                upd(1, tem[k].fir, dis[i][t]+tem[k].sec);
                if (tem[k+1].fir>tem[k].fir+1) upd(1, tem[k].fir+1, tem[k+1].fir-1, viaDefault);
            }
        }
    }
}
// Adds to ans the distances of all ordered pairs whose source lies in component k
// (members s, local indices 0..siz-1).
//  * Target outside the component: the path stays inside until its final edge z->y,
//    which is unmodified and costs a[z], so dist = leave[i] = min_z(dis[i][z]+a[s[z]]).
//  * Target inside: either the in-component distance dis[i][j], or leave the
//    component as cheaply as possible (leave[i]) to the cheapest outside vertex w
//    and jump from w straight to the target (cost a[w], edge unmodified).
//    Using a[s[i]] instead of leave[i] for this detour (the previous code) misses
//    paths that first walk inside the component before leaving.
void calc(int k, vector<int>& s) {
    dijkstra(k);
    int siz=s.size();
    vector<ll> leave(siz);
    for (int i=0; i<siz; ++i) {
        ll len=INF;
        for (int j=0; j<siz; ++j) len=min(len, dis[i][j]+a[s[j]]);
        leave[i]=len;
        ans+=1ll*len*(n-siz);
    }
    // Cheapest a[] among vertices outside component k (bge is sorted by a[]);
    // such a vertex exists exactly when the component is not the whole graph.
    bool hasOutside=siz<n;
    ll cheapest=0;
    if (hasOutside) {
        for (int j=1; j<=n; ++j) if (find(bge[j].sec)!=k) {cheapest=bge[j].fir; break;}
    }
    for (int i=0; i<siz; ++i) {
        ll detour=hasOutside?leave[i]+cheapest:INF;
        for (int j=0; j<siz; ++j) if (i!=j) ans+=min(dis[i][j], detour);
    }
}
// Reads the m modified edges, groups their endpoints into DSU components, adds
// a[v]*(n-1) for every untouched vertex v, runs calc() once per component, prints
// the total and terminates the program.
void solve() {
    for (int i=1; i<=n; ++i) dsu[i]=i;
    for (int i=1; i<=n; ++i) bge[i]=make(a[i], i);
    sort(bge+1, bge+n+1);
    for (int i=1; i<=m; ++i) {
        cg[i].build(); rel[cg[i].u].pb(cg[i]);
        vis[cg[i].u]=vis[cg[i].v]=1;
        dsu[find(cg[i].u)]=find(cg[i].v);
    }
    for (int i=1; i<=n; ++i) if (vis[i]) {s[find(i)].pb(i); rk[i]=s[find(i)].size()-1;}
    for (int i=1; i<=n; ++i) {
        // Every edge leaving an untouched vertex costs a[i], so each target costs a[i].
        if (!vis[i]) ans+=1ll*a[i]*(n-1);
        else if (!vis2[find(i)]) {
            calc(find(i), s[find(i)]);
            vis2[find(i)]=1;
        }
    }
    printf("%lld\n", ans);
    exit(0);
}
}
// Entry point: redirect stdin/stdout to happybean.in / happybean.out, read n, m and
// the default weights a[1..n], then hand off to task::solve(), which reads the m
// modified edges itself, prints the answer and exits.
signed main()
{
    freopen("happybean.in", "r", stdin);
    freopen("happybean.out", "w", stdout);
    n=read();
    m=read();
    for (int *p=a+1; p<=a+n; ++p) *p=read();
    // force::solve();   // brute-force alternative for small n
    task::solve();
    return 0;
}