题解 treecnt

传送门

一眼矩阵树
那么问题在于边权怎么设
考虑这样一种巧妙的 \(\require{enclose}\enclose{horizontalstrike}{\tt Observation}\) 构造
对于一个合法方案，对每个限制 \(S_i\)，给两个端点都在 \(S_i\) 内的每条树边 \(e_j\) 的权值加 \(1\)
那么一个 \(S\) 对一个合法的生成树权值和的影响是 +与之相关的边数
此时每条边的权值 \(w(i, j)\) 为同时包含 \(i, j\) 的限制个数
那么一个合法方案的 \(\sum e=\sum\max(|S_i|-1, 0)\)
发现这个东西就是生成树权值和能取到的上界
所以问题变成最大生成树计数了!(赛时只有一个人想出来,核情核理
可以在 \(O(n^3+\frac{n^2k}{\omega})\) 复杂度内解决

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 510
#define fir first
#define sec second
#define ll long long
//#define int long long

// Input buffer for the fast reader: p1 is the read cursor, p2 is one past the last byte
// delivered by the most recent fread.
char buf[1<<21], *p1=buf, *p2=buf;
// Buffered replacement for getchar(): refills buf from stdin when the cursor reaches the end
// and yields EOF once fread delivers nothing. Bytes are returned as unsigned char values so
// that a 0xFF byte cannot be mistaken for EOF and the result is always valid for isdigit().
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:(unsigned char)*p1++)
/**
 * Reads the next decimal integer from stdin through the buffered getchar().
 *
 * Skips every non-digit character before the number; each '-' seen while skipping flips the
 * sign. Consumes the digits plus the single character that terminates them.
 *
 * @return the parsed value, or 0 if end of input is reached before any digit is seen.
 */
inline int read() {
	int ans=0, f=1;
	int c=getchar();	// int, not char: EOF must stay distinguishable from data bytes
	// Stop at EOF as well; otherwise this loop would spin forever on exhausted input.
	while (c!=EOF && !isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	// ans*10 + digit, written with shifts; (c^48) maps '0'..'9' to 0..9.
	while (c!=EOF && isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

// n: number of vertices; k: number of constraint strings. Both are read in main().
int n, k;
// e[i][j] (1-indexed, stored symmetrically): the value read for the vertex pair (i, j).
// The solvers multiply these values along a tree, i.e. they act as edge multiplicities.
ll e[N][N];
// Modulus used for every reported answer.
const ll mod=998244353;
// st[i][j]=='1' (both indices 1-based) marks vertex j as a member of constraint set i.
char st[2010][N];

// namespace force{
// 	ll ans;
// 	bool vis[N];
// 	struct edge{int u, v; ll val;}sta[N], tem[N];
// 	int st2[N], dsu[N], top;
// 	inline int find(int p) {return dsu[p]==p?p:dsu[p]=find(dsu[p]);}
// 	bool check(int s) {
// 		int cnt=0;
// 		for (int i=1; i<=n; ++i)
// 			if (s&(1<<i)) vis[i]=1;
// 			else vis[i]=0;
// 		for (int i=1; i<n; ++i) if (vis[tem[i].u]&&vis[tem[i].v]) ++cnt;
// 		return cnt==__builtin_popcount(s)-1;
// 	}
// 	void solve() {
// 		for (int i=1; i<=n; ++i) dsu[i]=i;
// 		for (int i=1; i<n; ++i) for (int j=i+1; j<=n; ++j) if (e[i][j]) sta[top++]={i, j, e[i][j]}, dsu[find(i)]=find(j);
// 		int rot=find(1);
// 		for (int i=1; i<=n; ++i) if (find(i)!=rot) {puts("0"); return ;}
// 		for (int i=1; i<=k; ++i) for (int j=1; j<=n; ++j) if (st[i][j]=='1') st2[i]|=1<<j;
// 		++k; for (int j=1; j<=n; ++j) st2[k]|=1<<j;
// 		sort(st2+1, st2+k+1);
// 		k=unique(st2+1, st2+k+1)-st2-1;
// 		reverse(st2+1, st2+k+1);
// 		int lim=1<<top; ll sum;
// 		for (int s=0,cnt; s<lim; ++s) if (__builtin_popcount(s)==n-1) {
// 			cnt=0;
// 			for (int i=0; i<top; ++i) if (s&(1<<i)) tem[++cnt]=sta[i];
// 			for (int i=1; i<=k; ++i) if (!check(st2[i])) goto jump;
// 			sum=1;
// 			for (int i=1; i<=cnt; ++i) sum=sum*tem[i].val%mod;
// 			ans=(ans+sum)%mod;
// 			jump: ;
// 		}
// 		cout<<ans<<endl;
// 	}
// }

namespace task1{
	ll ans;
	bool vis[N];
	struct edge{int u, v; ll val;}sta[N], tem[N];
	int st2[N], dsu[N], top;
	inline int find(int p) {return dsu[p]==p?p:dsu[p]=find(dsu[p]);}
	bool check(int s) {
		int cnt=0;
		for (int i=1; i<=n; ++i) dsu[i]=i;
		for (int i=1; i<=n; ++i)
			if (s&(1<<i)) vis[i]=1;
			else vis[i]=0;
		for (int i=1; i<n; ++i) if (vis[tem[i].u]&&vis[tem[i].v]) dsu[find(tem[i].u)]=find(tem[i].v);
		int rot=0;
		for (int i=1; i<=n; ++i) if (s&(1<<i)) {
			if (rot) {if (find(i)!=rot) return 0;}
			else rot=find(i);
		}
		return 1;
	}
	void solve() {
		for (int i=1; i<=n; ++i) dsu[i]=i;
		for (int i=1; i<n; ++i) for (int j=i+1; j<=n; ++j) if (e[i][j]) sta[top++]={i, j, e[i][j]}, dsu[find(i)]=find(j);
		int rot=find(1);
		for (int i=1; i<=n; ++i) if (find(i)!=rot) {puts("0"); return ;}
		for (int i=1; i<=k; ++i) for (int j=1; j<=n; ++j) if (st[i][j]=='1') st2[i]|=1<<j;
		++k; for (int j=1; j<=n; ++j) st2[k]|=1<<j;
		sort(st2+1, st2+k+1);
		k=unique(st2+1, st2+k+1)-st2-1;
		reverse(st2+1, st2+k+1);
		int lim=1<<top; ll sum;
		for (int s=0,cnt; s<lim; ++s) if (__builtin_popcount(s)==n-1) {
			cnt=0;
			for (int i=0; i<top; ++i) if (s&(1<<i)) tem[++cnt]=sta[i];
			for (int i=1; i<=k; ++i) if (!check(st2[i])) goto jump;
			sum=1;
			for (int i=1; i<=cnt; ++i) sum=sum*tem[i].val%mod;
			ans=(ans+sum)%mod;
			jump: ;
		}
		cout<<ans<<endl;
	}
}

namespace task{
	// Main solver: weights every vertex pair by the number of constraint sets containing both
	// endpoints, then counts maximum spanning trees (weighted by the product of the input
	// edge values) with one determinant per distinct weight.

	// ans: running product of the per-weight determinants.
	// sum: maximum spanning tree weight minus the bound sum(max(|S_i|-1, 0)).
	ll ans=1, sum;
	// bel[v][i]==1 iff vertex v belongs to constraint set i.
	bitset<2010> bel[N];
	// uni[1..usiz]: weights used by the maximum spanning tree (made distinct later).
	// id[root]: index of a contracted component; cnt[i]: size of constraint set i.
	int uni[N], dsu[N], id[N], cnt[2010], usiz, tot, ecnt, top;
	// val: number of constraint sets containing both endpoints; tim: the input edge value.
	// e[1..ecnt] holds all ordered vertex pairs, sta[1..top] the chosen spanning tree edges.
	struct edge{int from, to, val; ll tim;}e[N*N], sta[N*N];
	// Union-find root lookup with path compression.
	inline int find(int p) {return dsu[p]==p?p:dsu[p]=find(dsu[p]);}
	// 1-indexed n x m matrix of residues modulo mod.
	struct matrix{
		int n, m;
		ll a[N][N];
		matrix() {n=m=0; memset(a, 0, sizeof(a));}
		matrix(int x, int y) {n=x; m=y; memset(a, 0, sizeof(a));}
		// Sets the logical size and zeroes exactly the cells a[1..x][1..y].
		void resize(int x, int y) {n=x; m=y; for (int i=1; i<=n; ++i) for (int j=1; j<=m; ++j) a[i][j]=0;}
		inline ll* operator [] (int t) {return a[t];}
		// Debug helper: prints the logical n x m part.
		inline void put() {for (int i=1; i<=n; ++i) {for (int j=1; j<=m; ++j) cout<<setw(2)<<a[i][j]<<' '; cout<<endl;}cout<<endl;}
		// Determinant of the logical n x m part modulo mod (destroys the contents).
		// Uses Euclid-style row reduction, so no modular inverse is needed.
		// @return the determinant as a value in [0, mod).
		inline ll gauss() {
			// The Euclidean steps below assume non-negative residues; normalise up front so
			// callers may store entries in any representative.
			for (int i=1; i<=n; ++i) for (int j=1; j<=m; ++j) a[i][j]=(a[i][j]%mod+mod)%mod;
			ll det=1;
			for (int i=1; i<=n; ++i) {
				for (int j=i+1; j<=n; ++j) {
					// gcd-like loop on (a[i][i], a[j][i]); ends with a[j][i]==0.
					while (a[j][i]) {
						ll t=a[i][i]/a[j][i];
						for (int c=i; c<=m; ++c) a[i][c]=((a[i][c]-a[j][c]*t)%mod+mod)%mod;
						swap(a[i], a[j]);
						det=-det;	// every row swap flips the sign
					}
				}
			}
			for (int i=1; i<=n; ++i) det=det*a[i][i]%mod;
			return (det+mod)%mod;
		}
	}mat;
	// Computes and prints the answer for the instance stored in the file-level globals.
	// Prints "0" when the maximum spanning tree weight does not reach the bound.
	void solve() {
		for (int i=1; i<=k; ++i) for (int j=1; j<=n; ++j) if (st[i][j]=='1') bel[j][i]=1, ++cnt[i];
		// Every ordered pair is stored, so each undirected edge appears in both directions;
		// the Laplacian fill below relies on that. count() yields size_t, which must be
		// converted explicitly because braced initialisation forbids narrowing.
		for (int i=1; i<=n; ++i) for (int j=1; j<=n; ++j)
			e[++ecnt]={i, j, static_cast<int>((bel[i]&bel[j]).count()), ::e[i][j]};
		// Kruskal in descending weight order builds one maximum spanning tree.
		sort(e+1, e+ecnt+1, [](edge a, edge b){return a.val>b.val;});
		for (int i=1; i<=n; ++i) dsu[i]=i;
		for (int i=1,s,t; i<=ecnt; ++i) if ((s=find(e[i].from))!=(t=find(e[i].to))) {
			dsu[s]=t;
			sta[++top]=e[i];
			sum+=uni[++usiz]=e[i].val;
		}
		// Constraints are satisfiable only if the maximum tree weight attains the bound.
		for (int i=1; i<=k; ++i) sum-=max(cnt[i]-1, 0);
		if (sum!=0) {puts("0"); return ;}
		sort(uni+1, uni+usiz+1);
		usiz=unique(uni+1, uni+usiz+1)-uni-1;
		for (int i=1; i<=usiz; ++i) {
			// Contract the tree edges of every other weight, then count spanning trees of the
			// contracted graph that use only edges of the current weight.
			for (int j=1; j<=n; ++j) dsu[j]=j, id[j]=0;
			for (int j=1; j<=top; ++j) if (sta[j].val!=uni[i]) dsu[find(sta[j].from)]=find(sta[j].to);
			tot=0;
			for (int j=1; j<=n; ++j) if (!id[find(j)]) id[find(j)]=++tot;
			// Only the (tot-1) x (tot-1) minor is zeroed and later read by gauss(); writes
			// that land in row or column tot fall outside it and are ignored.
			mat.resize(tot-1, tot-1);
			for (int j=1,s,t; j<=ecnt; ++j) if (e[j].val==uni[i]) {
				s=id[find(e[j].from)]; t=id[find(e[j].to)];
				mat[s][s]=(mat[s][s]+e[j].tim)%mod;
				// Keep the off-diagonal entry as a non-negative residue.
				mat[s][t]=((mat[s][t]-e[j].tim)%mod+mod)%mod;
			}
			ans=ans*mat.gauss()%mod;
		}
		cout<<(ans%mod+mod)%mod<<endl;
	}
}

signed main()
{
	freopen("treecnt.in", "r", stdin);
	freopen("treecnt.out", "w", stdout);

	scanf("%d%d", &n, &k);
	for (int i=1; i<n; ++i) for (int j=i+1; j<=n; ++j) scanf("%lld", &e[i][j]), e[j][i]=e[i][j];
	for (int i=1; i<=k; ++i) scanf("%s", st[i]+1);
	// force::solve();
	// if (n>50) puts("0");
	// else task1::solve();
	task::solve();

	return 0;
}
posted @ 2022-04-01 18:05  Administrator-09  阅读(2)  评论(0编辑  收藏  举报