BZOJ4386[POI2015]Wycieczki——矩阵乘法+倍增
题目描述
给定一张n个点m条边的带权有向图,每条边的边权只可能是1,2,3中的一种。
将所有可能的路径按路径长度排序,请输出第k小的路径的长度,注意路径不一定是简单路径,即可以重复走同一个点。
输入
第一行包含三个整数n,m,k(1<=n<=40,1<=m<=1000,1<=k<=10^18)。
接下来m行,每行三个整数u,v,c(1<=u,v<=n,u不等于v,1<=c<=3),表示从u出发有一条到v的单向边,边长为c。
可能有重边。
输出
包含一行一个正整数,即第k短的路径的长度,如果不存在,输出-1。
样例输入
6 6 11
1 2 1
2 3 2
3 4 2
4 5 1
5 3 1
4 6 3
1 2 1
2 3 2
3 4 2
4 5 1
5 3 1
4 6 3
样例输出
4
提示
长度为1的路径有1->2,5->3,4->5。
长度为2的路径有2->3,3->4,4->5->3。
长度为3的路径有4->6,1->2->3,3->4->5,5->3->4。
长度为4的路径有5->3->4->5。
因为边权有三种,但边数比较多,因此不能拆边。但点数比较少可以把每个点拆成三个点,同一个点拆成的三个点要连上边,这样就能使边权都是1了。
很容易想到用二分答案来求第k短路径,但这是log2,显然过不去,因此可以预处理出矩阵乘法的2i的矩阵,每次像倍增lca一样如果能走这么多步就走,不能走就尝试2i-1的矩阵的答案数。
那么怎么统计答案?
可以建一个原点(0号点)连向所有拆点后的原图节点,再将原点连向自己,这样第一行每个数就是原点到达对应点步数小于等于矩阵幂次的总路径数。
但这样求的是2i-1步数的答案,因此还要记录每个点的出度,统计时将每个答案乘上对应点的出度即可。
因为k比较大,矩阵乘法过程中会爆longlong,对于两个数加起来爆longlong,那么结果一定是负数,实际结果也就一定大于k,矩乘和求答案时判一下即可。
#include<set> #include<map> #include<queue> #include<stack> #include<cmath> #include<bitset> #include<vector> #include<cstdio> #include<cstring> #include<iostream> #include<algorithm> #define ll long long using namespace std; ll a[65][125][125]; ll b[125][125]; ll c[125][125]; int mask; ll ans; ll K; int cnt; int f[45][3]; int v[121]; int n,m; int x,y,z; void multiply(ll a[125][125],ll b[125][125],ll c[125][125]) { for(int i=0;i<=cnt;i++) { for(int j=0;j<=cnt;j++) { c[i][j]=0; for(int k=0;k<=cnt;k++) { if(a[i][k]&&b[k][j]) { if(a[i][k]<0||b[k][j]<0) { c[i][j]=-1; break; } if(a[i][k]>K/b[k][j]) { c[i][j]=-1; break; } c[i][j]+=a[i][k]*b[k][j]; if(c[i][j]<0) { c[i][j]=-1; break; } } } } } } bool check() { ll res=0; for(int i=0;i<=cnt;i++) { if(c[0][i]&&v[i]) { if(c[0][i]<0) { return 0; } if(c[0][i]>K/v[i]) { return 0; } res+=c[0][i]*v[i]; if(res<0) { return 0; } } } return res<K; } int main() { scanf("%d%d%lld",&n,&m,&K); for(int i=1;i<=n;i++) { for(int j=0;j<=2;j++) { f[i][j]=++cnt; } } a[0][0][0]++; for(int i=1;i<=n;i++) { for(int j=0;j<=1;j++) { a[0][f[i][j]][f[i][j+1]]++; } a[0][0][f[i][0]]++; } while(m--) { scanf("%d%d%d",&x,&y,&z); a[0][f[x][z-1]][f[y][0]]++; v[f[x][z-1]]++; } for(mask=0;(1ll<<mask)<=K*3;mask++); mask--; for(int i=1;i<=mask;i++) { multiply(a[i-1],a[i-1],a[i]); } for(int i=0;i<=cnt;i++) { b[i][i]=1; } for(int i=mask;i>=0;i--) { multiply(b,a[i],c); if(check()) { ans|=1ll<<i; for(int j=0;j<=cnt;j++) { for(int k=0;k<=cnt;k++) { b[j][k]=c[j][k]; } } } } ans++; if(ans>K*3) { ans=-1; } printf("%lld",ans); }