题解 题目交流通道
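一条边 \((i,j)\) 可以在 \([d[i][j],K]\) 内随意取值的条件是：存在 \(k\neq i,j\) 使得 \(d[i][j]=d[i][k]+d[k][j]\)。否则（缩点后两块之间的）所有边都必须 \(\ge d[i][j]\)，且至少有一条恰好等于 \(d[i][j]\)，方案数为 \((K-d+1)^{s_is_j}-(K-d)^{s_is_j}\)。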
对于距离为零的点对（\(d[i][j]=0\)），先用并查集缩点。每个大小为 \(s\) 的块内部，要求零权边使其连通，其余边在 \([0,K]\) 内任取。
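块内方案数用容斥（连通图计数）求得，见蓝书 P337：\(f_s=g_s-\sum_{j=1}^{s-1}\binom{s-1}{j-1}f_j\,g_{s-j}\,K^{j(s-j)}\)，其中 \(g_s=(K+1)^{s(s-1)/2}\)。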
Code:
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 410
#define ll long long
#define fir first
#define sec second
#define make make_pair
//#define int long long
char buf[1<<21], *p1=buf, *p2=buf;
// Buffered single-character reader: refills buf from stdin with fread once it
// is exhausted and yields EOF when fread returns nothing. Bytes are returned as
// unsigned char values, so a 0xFF byte is never confused with EOF and the
// result is always a valid argument for isdigit().
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:(unsigned char)*p1++)
// Reads the next decimal integer from the input, skipping every non-digit byte
// in front of it; each skipped '-' flips the sign. Returns 0 if the input ends
// before any digit is seen (previously this looped forever on EOF).
inline int read() {
    int ans=0, f=1;
    int c=getchar();  // int, not char: must be able to hold EOF distinctly
    while (!isdigit(c)) {
        if (c==EOF) return 0;
        if (c=='-') f=-f;
        c=getchar();
    }
    while (isdigit(c)) {
        ans=ans*10+(c-'0');
        c=getchar();
    }
    return ans*f;
}
int n, K;  // vertex count and upper bound of an edge weight (read in main)
const ll mod=998244353;  // prime modulus applied to every count
// Binary exponentiation. For b >= 0 the result is congruent to a^b modulo
// `mod`, and it lies in [0, mod) when a >= 0.
inline ll qpow(ll a, ll b) {
    ll res = 1;
    while (b != 0) {
        if (b & 1) {
            res = res * a % mod;
        }
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}
// Brute-force reference solver for tiny inputs (its matrices are 20x20 and
// 1-based, so n must be at most 19). It enumerates every weight in [0, K] for
// every unordered vertex pair, runs Floyd-Warshall, and counts the assignments
// whose shortest-path matrix equals the input d. It is meant for
// cross-checking task::solve() and is exponential in n(n-1)/2.
namespace force{
int d[20][20], dis[20][20], w[20][20], tot, ans;
pair<int, int> e[N];
// Computes all-pairs shortest paths of the current assignment w and counts the
// assignment if the resulting matrix matches d exactly.
void check() {
    memset(dis, 127, sizeof(dis));
    for (int i=1; i<=n; ++i) for (int j=1; j<=n; ++j) dis[i][j]=w[i][j];
    // The diagonal of w is never assigned (it keeps the 0x7f.. fill), so paths
    // through the diagonal are skipped.
    for (int k=1; k<=n; ++k)
        for (int i=1; i<=n; ++i) if (i!=k)
            for (int j=1; j<=n; ++j) if (i!=j && j!=k)
                dis[i][j]=min(dis[i][j], dis[i][k]+dis[k][j]);
#if 0
    cout<<"---w---"<<endl;
    for (int i=1; i<=n; ++i) {for (int j=1; j<=n; ++j) cout<<w[i][j]<<' '; cout<<endl;}
    cout<<endl;
    cout<<"---dis---"<<endl;
    for (int i=1; i<=n; ++i) {for (int j=1; j<=n; ++j) cout<<dis[i][j]<<' '; cout<<endl;}
    cout<<endl;
#endif
    // Compare the full matrix, not just the upper triangle. The diagonal must
    // be 0 and d must be symmetric (dis always is). This matches the validity
    // checks performed by task::solve().
    for (int i=1; i<=n; ++i)
        for (int j=1; j<=n; ++j) {
            const int expected = (i==j) ? 0 : dis[i][j];
            if (d[i][j]!=expected) return;
        }
    ++ans;
}
// Tries every weight in [0, K] on edge e[u] (both directions) and recurses;
// once all tot edges are assigned, check() evaluates the assignment.
void dfs(int u) {
    if (u>tot) {check(); return ;}
    for (int i=0; i<=K; ++i) {
        w[e[u].fir][e[u].sec]=w[e[u].sec][e[u].fir]=i;
        dfs(u+1);
    }
}
// Reads d, enumerates all weightings, prints the count and exits.
void solve() {
    // d, dis and w are 20x20 and 1-based; a larger n would index out of bounds.
    if (n>=20) {fputs("force: n must be at most 19\n", stderr); exit(1);}
    memset(w, 127, sizeof(w));
    for (int i=1; i<=n; ++i) for (int j=1; j<=n; ++j) d[i][j]=read();
    for (int i=1; i<=n; ++i) for (int j=i+1; j<=n; ++j) e[++tot]=make(i, j);
    dfs(1);
    printf("%d\n", ans);
    exit(0);
}
}
namespace task{
ll d[N][N], ans=1, f[N], g[N], dis[N][N], fac[N], inv[N];
int fa[N], siz[N], top;
bool vis[N];
pair<int, int> sta[N];
inline int find(int p) {return fa[p]==p?p:fa[p]=find(fa[p]);}
inline ll C(int n, int k) {return fac[n]*inv[k]%mod*inv[n-k]%mod;}
void solve() {
fac[0]=fac[1]=1; inv[0]=inv[1]=1;
for (int i=1; i<=n; ++i) fa[i]=i, siz[i]=1;
for (int i=1; i<=n; ++i) for (int j=1; j<=n; ++j) d[i][j]=read();
for (int i=1; i<=n; ++i) if (d[i][i]) {puts("0"); exit(0);}
for (int i=2; i<=n; ++i) fac[i]=fac[i-1]*i%mod;
for (int i=2; i<=n; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
for (int i=2; i<=n; ++i) inv[i]=inv[i-1]*inv[i]%mod;
for (int i=1; i<=n; ++i) for (int j=1; j<=n; ++j) if (d[i][j]>K || d[i][j]!=d[j][i]) {puts("0"); exit(0);}
for (int k=1; k<=n; ++k) for (int i=1; i<=n; ++i) for (int j=1; j<n; ++j) if (d[i][j]>d[i][k]+d[k][j]) {puts("0"); exit(0);}
for (int i=1,f1,f2; i<=n; ++i) for (int j=i+1; j<=n; ++j) if (!d[i][j]) {f1=find(i); f2=find(j); fa[f2]=f1; siz[f1]+=siz[f2];}
for (int i=1,f; i<=n; ++i) if (!vis[f=find(i)]) {sta[++top]=make(f, siz[f]); vis[f]=1;}
for (int i=1; i<=n; ++i) g[i]=qpow(K+1, i*(i-1)/2);
for (int i=1; i<=n; ++i) {f[i]=g[i]; for (int j=1; j<i; ++j) f[i]=(f[i]-f[j]*g[i-j]%mod*C(i-1, j-1)%mod*qpow(K, j*(i-j))%mod)%mod;}
for (int i=1; i<=top; ++i) ans=ans*f[sta[i].sec]%mod;
for (int i=1,f1,f2; i<=n; ++i) for (int j=i+1; j<=n; ++j) {f1=find(i); f2=find(j); if (f1!=f2) dis[f1][f2]=dis[f2][f1]=d[i][j];}
for (int i=1; i<=top; ++i) for (int j=i+1; j<=top; ++j) {
for (int k=1; k<=top; ++k) if (k!=i && k!=j && dis[sta[i].fir][sta[j].fir]==dis[sta[i].fir][sta[k].fir]+dis[sta[k].fir][sta[j].fir]) {
ans = ans * qpow(K-dis[sta[i].fir][sta[j].fir]+1, sta[i].sec*sta[j].sec)%mod;
goto jump;
}
ans = ans * (qpow(K-dis[sta[i].fir][sta[j].fir]+1, sta[i].sec*sta[j].sec)-qpow(K-dis[sta[i].fir][sta[j].fir], sta[i].sec*sta[j].sec))%mod;
jump: ;
}
printf("%lld\n", (ans%mod+mod)%mod);
exit(0);
}
}
// Entry point. It redirects stdio to the contest files c.in / c.out, reads the
// vertex count n and the weight bound K, then hands off to the solver, which
// prints the answer and exits.
signed main()
{
    freopen("c.in", "r", stdin);
    freopen("c.out", "w", stdout);
    n = read();
    K = read();
    // force::solve();  // brute-force cross-check for tiny n
    task::solve();
    return 0;
}