题解 题目交流通道

传送门

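一条边 \((i,j)\) 可以在 \([d[i][j],K]\) 内随意取值的条件是:存在 \(k\ne i,j\) 使得 \(d[i][j]=d[i][k]+d[k][j]\);否则(缩点后)两个点集之间至少要有一条边的权值恰好等于 \(d[i][j]\)。
对于距离为零的点对,考虑用并查集缩点
缩点后每个点集内部(零权边连通)的方案数用容斥计算,见蓝书 P337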

Code:
#include <bits/stdc++.h>
using namespace std;
// "Infinity" sentinel for int values; not referenced in the visible code.
#define INF 0x3f3f3f3f
// Capacity of the per-vertex arrays in namespace task (and of force::e).
#define N 410
// Shorthand for the 64-bit integer type used for all modular arithmetic.
#define ll long long
// Shorthands for std::pair members and std::make_pair.
#define fir first
#define sec second
#define make make_pair
// Disabled switch that would widen every `int` to 64 bits; main is declared
// `signed main()` so that it would still compile with this enabled.
//#define int long long

// Input buffer for the fast reader: [p1, p2) is the unread part of buf.
char buf[1<<21], *p1=buf, *p2=buf;
// Replacement for getchar(): refills buf from stdin with fread when it is
// exhausted and yields EOF when nothing more can be read. The byte is returned
// as unsigned char so that values >= 0x80 are neither negative (undefined
// behaviour in isdigit) nor confused with EOF.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:(unsigned char)*p1++)
/// Reads the next decimal integer from stdin.
/// Skips every non-digit character; each '-' met while skipping flips the sign.
/// @return the parsed value, or 0 if the input ends before any digit is found.
/// No overflow detection is performed.
inline int read() {
	int ans=0, f=1;
	int c=getchar();
	// Stop on EOF as well: isdigit(EOF) is false, so without this guard a
	// truncated or missing input would spin here forever.
	while (c!=EOF && !isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (c!=EOF && isdigit(c)) {ans=ans*10+(c-'0'); c=getchar();}
	return ans*f;
}

// Problem parameters filled in by main(): n is the side length of the distance
// matrix, K is the largest edge weight that may be assigned.
int n, K;
// Prime modulus under which all counts are reported.
const long long mod=998244353;
/// Modular exponentiation by repeated squaring.
/// @param a base; any 64-bit value, reduced into [0, mod) before use so the
///          intermediate products cannot overflow
/// @param b exponent; values <= 0 yield 1
/// @return a^b modulo mod, always in [0, mod)
inline long long qpow(long long a, long long b) {
	a%=mod;
	if (a<0) a+=mod;
	long long ans=1;
	// `b>0` (not `b!=0`): right-shifting a negative exponent never reaches 0.
	for (; b>0; a=a*a%mod, b>>=1) if (b&1) ans=ans*a%mod;
	return ans;
}

namespace force{
	int d[20][20], dis[20][20], w[20][20], tot, ans;
	pair<int, int> e[N];
	void check() {
		memset(dis, 127, sizeof(dis));
		for (int i=1; i<=n; ++i) for (int j=1; j<=n; ++j) dis[i][j]=w[i][j];
		for (int k=1; k<=n; ++k)
			for (int i=1; i<=n; ++i) if (i!=k)
				for (int j=1; j<=n; ++j) if (i!=j && j!=k)
					dis[i][j]=min(dis[i][j], dis[i][k]+dis[k][j]);
		#if 0
		cout<<"---w---"<<endl;
		for (int i=1; i<=n; ++i) {for (int j=1; j<=n; ++j) cout<<w[i][j]<<' '; cout<<endl;}
		cout<<endl;
		cout<<"---dis---"<<endl;
		for (int i=1; i<=n; ++i) {for (int j=1; j<=n; ++j) cout<<dis[i][j]<<' '; cout<<endl;}
		cout<<endl;
		#endif
		for (int i=1; i<=n; ++i) for (int j=i+1; j<=n; ++j) if (dis[i][j]!=d[i][j]) return ;
		++ans;
	}
	void dfs(int u) {
		if (u>tot) {check(); return ;}
		for (int i=0; i<=K; ++i) {
			w[e[u].fir][e[u].sec]=w[e[u].sec][e[u].fir]=i;
			dfs(u+1);
		}
	}
	void solve() {
		memset(w, 127, sizeof(w));
		for (int i=1; i<=n; ++i) for (int j=1; j<=n; ++j) d[i][j]=read();
		for (int i=1; i<=n; ++i) for (int j=i+1; j<=n; ++j) e[++tot]=make(i, j);
		dfs(1);
		printf("%d\n", ans);
		exit(0);
	}
}

namespace task{
	ll d[N][N], ans=1, f[N], g[N], dis[N][N], fac[N], inv[N];
	int fa[N], siz[N], top;
	bool vis[N];
	pair<int, int> sta[N];
	inline int find(int p) {return fa[p]==p?p:fa[p]=find(fa[p]);}
	inline ll C(int n, int k) {return fac[n]*inv[k]%mod*inv[n-k]%mod;}
	void solve() {
		fac[0]=fac[1]=1; inv[0]=inv[1]=1;
		for (int i=1; i<=n; ++i) fa[i]=i, siz[i]=1;
		for (int i=1; i<=n; ++i) for (int j=1; j<=n; ++j) d[i][j]=read();
		for (int i=1; i<=n; ++i) if (d[i][i]) {puts("0"); exit(0);}
		for (int i=2; i<=n; ++i) fac[i]=fac[i-1]*i%mod;
		for (int i=2; i<=n; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
		for (int i=2; i<=n; ++i) inv[i]=inv[i-1]*inv[i]%mod;
		for (int i=1; i<=n; ++i) for (int j=1; j<=n; ++j) if (d[i][j]>K || d[i][j]!=d[j][i]) {puts("0"); exit(0);}
		for (int k=1; k<=n; ++k) for (int i=1; i<=n; ++i) for (int j=1; j<n; ++j) if (d[i][j]>d[i][k]+d[k][j]) {puts("0"); exit(0);}
		for (int i=1,f1,f2; i<=n; ++i) for (int j=i+1; j<=n; ++j) if (!d[i][j]) {f1=find(i); f2=find(j); fa[f2]=f1; siz[f1]+=siz[f2];}
		for (int i=1,f; i<=n; ++i) if (!vis[f=find(i)]) {sta[++top]=make(f, siz[f]); vis[f]=1;}
		for (int i=1; i<=n; ++i) g[i]=qpow(K+1, i*(i-1)/2);
		for (int i=1; i<=n; ++i) {f[i]=g[i]; for (int j=1; j<i; ++j) f[i]=(f[i]-f[j]*g[i-j]%mod*C(i-1, j-1)%mod*qpow(K, j*(i-j))%mod)%mod;}
		for (int i=1; i<=top; ++i) ans=ans*f[sta[i].sec]%mod;
		for (int i=1,f1,f2; i<=n; ++i) for (int j=i+1; j<=n; ++j) {f1=find(i); f2=find(j); if (f1!=f2) dis[f1][f2]=dis[f2][f1]=d[i][j];}
		for (int i=1; i<=top; ++i) for (int j=i+1; j<=top; ++j) {
			for (int k=1; k<=top; ++k) if (k!=i && k!=j && dis[sta[i].fir][sta[j].fir]==dis[sta[i].fir][sta[k].fir]+dis[sta[k].fir][sta[j].fir]) {
				ans = ans * qpow(K-dis[sta[i].fir][sta[j].fir]+1, sta[i].sec*sta[j].sec)%mod;
				goto jump;
			}
			ans = ans * (qpow(K-dis[sta[i].fir][sta[j].fir]+1, sta[i].sec*sta[j].sec)-qpow(K-dis[sta[i].fir][sta[j].fir], sta[i].sec*sta[j].sec))%mod;
			jump: ;
		}
		printf("%lld\n", (ans%mod+mod)%mod);
		exit(0);
	}
}

// Entry point: binds stdin/stdout to the files "c.in"/"c.out", reads n and K,
// then hands over to task::solve(), which prints the answer and exits.
// The brute-force checker force::solve() can be called instead for debugging.
signed main()
{
	freopen("c.in", "r", stdin);
	freopen("c.out", "w", stdout);

	n=read();
	K=read();
	task::solve();

	return 0;
}
posted @ 2021-09-26 21:13  Administrator-09  阅读(4)  评论(0)  编辑  收藏  举报