题解 题目交流通道
Description
Solution
先考虑没有 \(d_{i,j}\) 为 \(0\) 的情况 . 这时对于一个 \(d_{i,j}\) 若存在 \(d_{i,j}=d_{i,k}+d_{k,j}\), 那么这条边的代价可以在 \(d_{i,j}-k\) 中任取 . 因为不论怎么取都会有 \(d_{i,k}+d_{k,j}\) 来满足它 , 而且因为 \(d_{i,j}>0\), 所以不会有三条边都任选的情况 .
然后考虑存在 \(d_{i,j}\) 为 \(0\), 先把所有距离为 \(0\) 的点缩在一起来处理 , 一个集合内部的情况可以 dp 得出 .
设 \(f(n)\) 为大小为 \(n\) 的集合的方案总数 , 转移考虑枚举节点 \(1\) 所在的联通块大小来容斥 .
\(f(n)=(k+1)^{\binom{n}{2}}-\sum\limits_{i=1}^{n-1}f_i*(k+1)^{\binom{n-i}{2}}*\binom{n-1}{i-1}*k^{i(n-i)}\)
对于集合之间的边的情况 .
设该边为 \(d_{u,v}\) , 数量为 \(a\).
如果存在 \(d_{u,k}+d_{k,v}=d_{u,v}\), 那么这些边的值可以在 \(d_{u,v}-k\) 任选 , 方案数为 \((k-d_{u,v}+1)^a\)
否则至少要有一条边取到 \(d_{u,v}\), 方案数为 \((k-d_{u,v}+1)^a-(k-d_{u,v})^a\)
时间复杂度 \(O(n^3)\)
Code
#include<iostream>
#include<cstdio>
typedef long long ll;
using namespace std;
int read()
{
int ret=0;char c=getchar();
while(c>'9'||c<'0')c=getchar();
while(c>='0'&&c<='9')ret=(ret<<3)+(ret<<1)+(c^48),c=getchar();
return ret;
}
const int maxn=405;
const int mod=998244353;
int n,k;
int ans;
int d[maxn][maxn];
int fac[maxn],inv[maxn],powk[maxn*maxn],powk1[maxn*maxn];
int f[maxn];
int qpow(int a,int b)
{
int ret=1;
for(;b;b>>=1)
{
if(b&1)ret=(ll)ret*a%mod;
a=(ll)a*a%mod;
}
return ret;
}
int C(int n,int m){return (ll)fac[n]*inv[n-m]%mod*inv[m]%mod;}
void prework()
{
powk[0]=1;
for(int i=1;i<=n*n;i++)powk[i]=(ll)powk[i-1]*k%mod;
powk1[0]=1;
for(int i=1;i<=n*n;i++)powk1[i]=(ll)powk1[i-1]*(k+1)%mod;
fac[0]=1;
for(int i=1;i<=n;i++)fac[i]=(ll)fac[i-1]*i%mod;
inv[0]=inv[1]=1;
for(int i=2;i<=n;i++)inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;
for(int i=2;i<=n;i++)inv[i]=(ll)inv[i]*inv[i-1]%mod;
f[1]=1;
for(int i=2;i<=n;i++)
{
f[i]=powk1[i*(i-1)/2];
for(int j=1;j<i;j++)f[i]=((ll)f[i]-(ll)f[j]*powk1[(i-j)*(i-j-1)/2]%mod*C(i-1,j-1)%mod*powk[j*(i-j)]%mod+mod)%mod;
}
}
struct dsu
{
int fa[maxn],siz[maxn];bool vis[maxn];
void prework(){for(int i=1;i<=n;i++)fa[i]=i,siz[i]=1;}
int get(int x){return x==fa[x]?x:fa[x]=get(fa[x]);}
void merge(int x,int y)
{
x=get(x);y=get(y);if(x==y)return;
fa[y]=x;siz[x]+=siz[y];
}
void calc(int x)
{
if(vis[x])return;
ans=(ll)ans*f[siz[x]]%mod;
vis[x]=1;
}
}S;
bool used[maxn][maxn];
int main()
{
n=read();k=read();
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)d[i][j]=read();
S.prework();
for(int i=1;i<=n;i++)
{
if(d[i][i]){printf("0");return 0;}
for(int j=1;j<=n;j++)
{
if(!d[i][j])S.merge(i,j);
if(d[i][j]>k||d[i][j]!=d[j][i]){printf("0");return 0;}
for(int k=1;k<=n;k++)if(d[i][k]+d[k][j]<d[i][j]){printf("0");return 0;}
}
}
ans=1;
prework();
for(int i=1;i<=n;i++)
for(int j=i+1;j<=n;j++)
{
if(S.get(i)==S.get(j)||used[S.get(j)][S.get(i)])continue;
used[S.get(j)][S.get(i)]=used[S.get(i)][S.get(j)]=1;
bool flag=0;
for(int l=1;l<=n;l++)
{
if(S.get(l)==S.get(i)||S.get(l)==S.get(j))continue;
if(d[i][l]+d[l][j]==d[i][j]){flag=1;break;}
}
if(flag)ans=(ll)ans*qpow(k-d[i][j]+1,S.siz[S.get(i)]*S.siz[S.get(j)])%mod;
else ans=(ll)ans*(qpow(k-d[i][j]+1,S.siz[S.get(i)]*S.siz[S.get(j)])-qpow(k-d[i][j],S.siz[S.get(i)]*S.siz[S.get(j)])+mod)%mod;
}
for(int i=1;i<=n;i++)S.calc(S.get(i));
printf("%d",ans);
return 0;
}