【洛谷P6085】吃货 JYY
题目
题目链接:https://www.luogu.com.cn/problem/P6085
世界上一共有 \(N\) 个 JYY 愿意去的城市,分别从 \(1\) 编号到 \(N\)。JYY 选出了 \(K\) 个他一定要乘坐的航班。除此之外,还有 \(M\) 个 JYY 没有特别的偏好,可以乘坐也可以不乘坐的航班。
一个航班我们用一个三元组 \((x,y,z)\) 来表示,意义是这趟航班连接城市 \(x\) 和 \(y\),并且机票费用是 \(z\)。每个航班都是往返的,所以 JYY 花费 \(z\) 的钱,既可以选择从 \(x\) 飞往 \(y\),也可以选择从 \(y\) 飞往 \(x\)。
南京的编号是 \(1\),现在 JYY 打算从南京出发,乘坐所有 K 个航班,并且最后回到南京,请你帮他求出最小的花费。
\(n\leq 13,k\leq 78,m\leq 200\)。
思路
最终选出的肯定是一个回路。所以我们可以把每个点度数的奇偶性状压一下进行 dp。
设 \(f[s]\) 表示状态为 \(s\) 时的最小代价。对于每一个点,\(0\) 表示不在当前连通块中,\(1\) 表示度数为奇数,\(2\) 表示度数为偶数。把这个三进制状压起来当做状态 \(s\)。
转移的话就枚举在连通块内的一个点 \(i\),以及一个连通块外的点 \(j\),然后考虑把 \(i\) 和 \(j\) 连起来。
如果 \(i,j\) 之间有必须走的边,那么直接转移并且不用计算长度。否则需要更新 \(i,j\) 的度数并加上 \(i,j\) 之间最短路的长度。
统计答案需要再预处理出 \(g[s]\) 表示集合 \(s\) 内的点度数是奇数,把他们两两连接的最小代价。注意这里的 \(s\) 是二进制。
最后枚举每一种三进制的状态,把其中度数为奇数的点拿出来加上 \(g\) 的贡献,取最小值后加上所有必须走的边的长度即可。
时间复杂度 \(O(n^2(3^n+2^n))\)。
代码
#include <bits/stdc++.h>
using namespace std;
const int N=15,M=1594330,Inf=1e9;
int n,m1,m2,ans,sum,S,dis[N][N],pw[N],f[M],g[M];
bool v[N][N],vis[N];
queue<int> q;
int main()
{
scanf("%d%d",&n,&m1);
memset(dis,0x3f3f3f3f,sizeof(dis));
for (int i=1,x,y,z;i<=m1;i++)
{
scanf("%d%d%d",&x,&y,&z);
dis[x][y]=dis[y][x]=min(dis[x][y],z);
sum+=z; v[x][y]=v[y][x]=vis[x]=vis[y]=1;
S^=(1<<x-1)^(1<<y-1);
}
vis[1]=1;
scanf("%d",&m2);
for (int i=1,x,y,z;i<=m2;i++)
{
scanf("%d%d%d",&x,&y,&z);
dis[x][y]=dis[y][x]=min(dis[x][y],z);
}
for (int k=1;k<=n;k++)
for (int i=1;i<=n;i++)
for (int j=1;j<=n;j++)
if (i!=j && j!=k && i!=k)
dis[i][j]=min(dis[i][j],dis[i][k]+dis[k][j]);
memset(f,0x3f3f3f3f,sizeof(f));
pw[0]=1;
for (int i=1;i<=n;i++)
{
pw[i]=pw[i-1]*3; f[pw[i-1]*2]=0;
q.push(pw[i-1]*2);
}
while (q.size())
{
int s=q.front(); q.pop();
for (int j=1,ss;j<=n;j++)
if (!((s/pw[j-1])%3))
for (int i=1;i<=n;i++)
if ((s/pw[i-1])%3)
if (!v[i][j])
{
if ((s/pw[i-1])%3==1) ss=s+pw[j-1]+pw[i-1];
else ss=s+pw[j-1]-pw[i-1];
if (f[ss]>Inf) q.push(ss);
f[ss]=min(f[ss],f[s]+dis[i][j]);
}
else
{
ss=s+2*pw[j-1];
if (f[ss]>Inf) q.push(ss);
f[ss]=min(f[ss],f[s]);
}
}
memset(g,0x3f3f3f3f,sizeof(g));
g[0]=0; ans=Inf;
for (int s=0;s<(1<<n);s++)
for (int i=1;i<=n;i++)
if (s&(1<<i-1))
for (int j=i+1;j<=n;j++)
if (s&(1<<j-1))
g[s]=min(g[s],g[s^(1<<i-1)^(1<<j-1)]+dis[i][j]);
for (int s=0;s<pw[n];s++)
{
int ss=0; bool flag=1;
for (int i=0;i<n;i++)
{
if ((s/pw[i])%3==1) ss+=(1<<i);
if (vis[i+1] && !((s/pw[i])%3)) { flag=0; break; }
}
if (flag) ans=min(ans,f[s]+g[ss^S]);
}
cout<<ans+sum;
return 0;
}