Dist
数据范围:\(1<=n<=10^5,1<=k<=18,1<=k_i<=10^7,\sum |s_i|<=3*10^5\)
Solution
因为\(k\)比较小所以显然是拿\(k\)来搞事
比较简单粗暴的想法是,我们直接预处理出每个完全图中的一个点(好吧其实就是团==)到其他的完全图中一个点的最短距离,然后用这个东西来计算两个点之间的最短路就方便很多了
首先可以用一个\(mp[i][j]\)表示团\(i\)到团\(j\)的最短距离,这个数组在初始化的话直枚举每个节点,然后任选这个点所在团中的两个,取两个团的边权的最小值作为直接连接这两个团的距离,也就是初始的\(mp[i][j]\),然后floyd一下就得到完整的\(mp\)数组了
接下来考虑怎么用这个玩意计算两个节点的最短路
考虑固定一个起点\(x\),我们要计算\(x\)到其他点的最短路,我们可以先用预处理出的\(mp\)数组算出\(x\)到任意一个团\(j\)中的一个点(非\(x\))的最短路,记\(dis[j]\),具体一点的话就是枚举\(x\)先走一步到团\(i\)与团\(j\)的最短路的起点,然后直接走最短路到团\(j\)中某个节点,也就是\(min(k[i]+mp[i][j])\)
然后\(x\)到\(y\)的最短路一定是\(y\)所属的团的\(dis\)值的最小值,所以我们考虑将\(dis\)排序,从小到大计算贡献,当前团\(j\)的贡献应该就是\(dis[j]\)乘上当前团中没有被前面的团所包含的节点的数量
现在考虑这个数量怎么计算
因为\(k\)比较小,所以我们可以把当前已经考虑完的团具体是哪些给压成一个二进制数\(nowst\)(为\(1\)表示还没有考虑),每一个原图中的节点\(x\)所属的团也压成一个二进制数\(st[x]\),如果说当前考虑的团中包含的一个节点\(x\)满足\(st[x]|nowst=nowst\),那么说明这个节点不属于前面考虑的任何一个团,即有\(1\)的贡献,所以现在的问题就变成了对于每一个团\(i\)我们要预处理一个\(cnt[i][j]\)表示该团满足\(st[x]|j=j\)的节点\(x\)有多少个,这个东西的计算。。跟某fwt题里面的操作一样
然后我们只要枚举固定的起点\(x\)然后按照上面的步骤计算贡献就ok了
需要注意的是每次计算贡献的时候因为我们的\(dis\)处理出来的是\(x\)走到另外一个点的距离,也就是意味着我们在算的时候会把\(x\)到\(x\)的距离算成\(x\)走到排序后第一个团中某个其他节点的距离,所以需要判一下(减掉就好了),以及最后的答案要除以\(2\)(因为每条边算了\(2\)次)
Code
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const int N=1e5+10,ST=(1<<18)+10;
const ll inf=1LL<<60;
ll mp[20][20];
int rec[20][N],cnt[20][ST],w[20];
ll dis[20];
int st[N],ord[N];
int n,m;
ll ans;
int St(int x){return 1<<x-1;}
int in(int st,int x){return st>>x-1&1;}
void prework(){
for (int i=1;i<=n;++i){
for (int j=1;j<=m;++j){
if (!in(st[i],j)) continue;
++cnt[j][st[i]];
for (int k=1;k<=m;++k)
if (in(st[i],k))
mp[j][k]=min(mp[j][k],1LL*w[k]);
}
}
for (int k=1;k<=m;++k)
for (int i=1;i<=m;++i){
if (i==k||mp[i][k]==inf) continue;
for (int j=1;j<=m;++j){
if (j==k||j==i||mp[k][j]==inf) continue;
mp[i][j]=min(mp[i][j],mp[i][k]+mp[k][j]);
}
}
for (int i=1;i<=m;++i)
for (int j=0;j<m;++j)
for (int k=0;k<1<<m;++k)
if (!(k>>j&1))
cnt[i][k|(1<<j)]+=cnt[i][k];
}
bool cmp(int x,int y){return dis[x]<dis[y];}
void solve(){
ll tmp;
int nowst;
for (int x=1;x<=n;++x){
for (int i=1;i<=m;++i) dis[i]=inf,ord[i]=i;
for (int i=1;i<=m;++i){
if (!in(st[x],i)) continue;
for (int j=1;j<=m;++j)
dis[j]=min(dis[j],mp[i][j]+w[i]);
}
sort(ord+1,ord+1+m,cmp);
nowst=(1<<m)-1;
tmp=-dis[ord[1]];
for (int i=1;i<=m;++i){
tmp+=dis[ord[i]]*cnt[ord[i]][nowst];
nowst^=St(ord[i]);
}
ans+=tmp;
}
printf("%lld\n",ans/2);
}
int main(){
#ifndef ONLINE_JUDGE
freopen("a.in","r",stdin);
#endif
int x;
scanf("%d%d",&n,&m);
for (int i=1;i<=m;++i){
for (int j=1;j<=m;++j) mp[i][j]=i==j?0:inf;
scanf("%d%d",&w[i],&rec[i][0]);
for (int j=1;j<=rec[i][0];++j){
scanf("%d",&x);
rec[i][j]=x;
st[x]|=St(i);
}
}
prework();
solve();
}