题解 treecnt
一眼矩阵树
那么问题在于边权怎么设
考虑这样一种巧妙的 \(\require{enclose}\enclose{horizontalstrike}{\tt Observation}\) 构造
对于一个合法方案,\(\forall j\in S_i,\ \text{add 1 to}\ e_j\)(即对每个限制 \(S_i\),给两个端点都在 \(S_i\) 内的每条边 \(e_j\) 的权值加 \(1\))
那么一个 \(S_i\) 对一棵合法生成树权值和的贡献,就是树上两个端点都在 \(S_i\) 内的边数(\(S_i\) 在树上连通时恰为 \(|S_i|-1\))
此时每条边的权值 \(w(i, j)\) 为同时包含 \(i, j\) 的限制个数
那么一个合法方案的 \(\sum e=\sum\max(|S_i|-1, 0)\)
发现这个东西就是生成树权值和能取到的上界
所以问题变成最大生成树计数了!(赛时只有一个人想出来,核情核理
可以在 \(O(n^3+\frac{n^2k}{\omega})\) 复杂度内解决
点击查看代码
#include <bits/stdc++.h>
using namespace std;
// Project-wide shorthands. N is the dimension used for the per-vertex arrays
// declared below (vertices are indexed from 1).
#define INF 0x3f3f3f3f
#define N 510
#define fir first
#define sec second
#define ll long long
//#define int long long
// 2 MiB input buffer; p1 is the read cursor and p2 the end of the valid data.
char buf[1<<21], *p1=buf, *p2=buf;
// Replaces getchar() with a buffered reader: when the cursor reaches the end,
// the buffer is refilled from stdin with fread; yields EOF when the refill
// returns no data.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
// Reads the next decimal integer through getchar().
// Characters before the first digit are skipped; every '-' among them flips
// the sign. Returns 0 when the input ends before any digit is seen (the
// previous version looped forever in that case).
inline int read() {
    int value=0, sign=1;
    // Keep the character as int so EOF stays distinguishable, and compare
    // against '0'..'9' directly: isdigit() is undefined for negative values
    // other than EOF.
    int c=getchar();
    while (c!=EOF && (c<'0' || c>'9')) {
        if (c=='-') sign=-sign;
        c=getchar();
    }
    while (c>='0' && c<='9') {
        value=value*10+(c-'0');
        c=getchar();
    }
    return value*sign;
}
// n and k are the two integers read first in main(); n bounds the vertex
// loops and k the number of strings stored in st.
int n, k;
// Symmetric per-pair value read in main() for every i<j and mirrored to
// e[j][i]; the solvers multiply these values into the answer.
ll e[N][N];
// All results are reported modulo this prime.
const ll mod=998244353;
// st[i] (1-based, characters stored from offset 1) is the i-th input string;
// the solvers treat st[i][j]=='1' as "vertex j belongs to set i".
char st[2010][N];
// namespace force{
// ll ans;
// bool vis[N];
// struct edge{int u, v; ll val;}sta[N], tem[N];
// int st2[N], dsu[N], top;
// inline int find(int p) {return dsu[p]==p?p:dsu[p]=find(dsu[p]);}
// bool check(int s) {
// int cnt=0;
// for (int i=1; i<=n; ++i)
// if (s&(1<<i)) vis[i]=1;
// else vis[i]=0;
// for (int i=1; i<n; ++i) if (vis[tem[i].u]&&vis[tem[i].v]) ++cnt;
// return cnt==__builtin_popcount(s)-1;
// }
// void solve() {
// for (int i=1; i<=n; ++i) dsu[i]=i;
// for (int i=1; i<n; ++i) for (int j=i+1; j<=n; ++j) if (e[i][j]) sta[top++]={i, j, e[i][j]}, dsu[find(i)]=find(j);
// int rot=find(1);
// for (int i=1; i<=n; ++i) if (find(i)!=rot) {puts("0"); return ;}
// for (int i=1; i<=k; ++i) for (int j=1; j<=n; ++j) if (st[i][j]=='1') st2[i]|=1<<j;
// ++k; for (int j=1; j<=n; ++j) st2[k]|=1<<j;
// sort(st2+1, st2+k+1);
// k=unique(st2+1, st2+k+1)-st2-1;
// reverse(st2+1, st2+k+1);
// int lim=1<<top; ll sum;
// for (int s=0,cnt; s<lim; ++s) if (__builtin_popcount(s)==n-1) {
// cnt=0;
// for (int i=0; i<top; ++i) if (s&(1<<i)) tem[++cnt]=sta[i];
// for (int i=1; i<=k; ++i) if (!check(st2[i])) goto jump;
// sum=1;
// for (int i=1; i<=cnt; ++i) sum=sum*tem[i].val%mod;
// ans=(ans+sum)%mod;
// jump: ;
// }
// cout<<ans<<endl;
// }
// }
namespace task1{
ll ans;
bool vis[N];
struct edge{int u, v; ll val;}sta[N], tem[N];
int st2[N], dsu[N], top;
inline int find(int p) {return dsu[p]==p?p:dsu[p]=find(dsu[p]);}
bool check(int s) {
int cnt=0;
for (int i=1; i<=n; ++i) dsu[i]=i;
for (int i=1; i<=n; ++i)
if (s&(1<<i)) vis[i]=1;
else vis[i]=0;
for (int i=1; i<n; ++i) if (vis[tem[i].u]&&vis[tem[i].v]) dsu[find(tem[i].u)]=find(tem[i].v);
int rot=0;
for (int i=1; i<=n; ++i) if (s&(1<<i)) {
if (rot) {if (find(i)!=rot) return 0;}
else rot=find(i);
}
return 1;
}
void solve() {
for (int i=1; i<=n; ++i) dsu[i]=i;
for (int i=1; i<n; ++i) for (int j=i+1; j<=n; ++j) if (e[i][j]) sta[top++]={i, j, e[i][j]}, dsu[find(i)]=find(j);
int rot=find(1);
for (int i=1; i<=n; ++i) if (find(i)!=rot) {puts("0"); return ;}
for (int i=1; i<=k; ++i) for (int j=1; j<=n; ++j) if (st[i][j]=='1') st2[i]|=1<<j;
++k; for (int j=1; j<=n; ++j) st2[k]|=1<<j;
sort(st2+1, st2+k+1);
k=unique(st2+1, st2+k+1)-st2-1;
reverse(st2+1, st2+k+1);
int lim=1<<top; ll sum;
for (int s=0,cnt; s<lim; ++s) if (__builtin_popcount(s)==n-1) {
cnt=0;
for (int i=0; i<top; ++i) if (s&(1<<i)) tem[++cnt]=sta[i];
for (int i=1; i<=k; ++i) if (!check(st2[i])) goto jump;
sum=1;
for (int i=1; i<=cnt; ++i) sum=sum*tem[i].val%mod;
ans=(ans+sum)%mod;
jump: ;
}
cout<<ans<<endl;
}
}
namespace task{
ll ans=1, sum;
bitset<2010> bel[N];
int uni[N], dsu[N], id[N], cnt[2010], usiz, tot, ecnt, top;
struct edge{int from, to, val; ll tim;}e[N*N], sta[N*N];
inline int find(int p) {return dsu[p]==p?p:dsu[p]=find(dsu[p]);}
struct matrix{
int n, m;
ll a[N][N];
matrix() {n=m=0; memset(a, 0, sizeof(a));}
matrix(int x, int y) {n=x; m=y; memset(a, 0, sizeof(a));}
void resize(int x, int y) {n=x; m=y; for (int i=1; i<=n; ++i) for (int j=1; j<=m; ++j) a[i][j]=0;}
inline ll* operator [] (int t) {return a[t];}
inline void put() {for (int i=1; i<=n; ++i) {for (int j=1; j<=m; ++j) cout<<setw(2)<<a[i][j]<<' '; cout<<endl;}cout<<endl;}
inline ll gauss() {
ll det=1;
for (int i=1; i<=n; ++i) {
for (int j=i+1; j<=n; ++j) {
while (a[j][i]) {
ll t=a[i][i]/a[j][i];
for (int k=i; k<=m; ++k) a[i][k]=((a[i][k]-a[j][k]*t)%mod+mod)%mod;
swap(a[i], a[j]);
det=-det;
}
}
}
for (int i=1; i<=n; ++i) det=det*a[i][i]%mod;
return det;
}
}mat;
void solve() {
for (int i=1; i<=k; ++i) for (int j=1; j<=n; ++j) if (st[i][j]=='1') bel[j][i]=1, ++cnt[i];
for (int i=1; i<=n; ++i) for (int j=1; j<=n; ++j) e[++ecnt]={i, j, (bel[i]&bel[j]).count(), ::e[i][j]};
sort(e+1, e+ecnt+1, [](edge a, edge b){return a.val>b.val;});
for (int i=1; i<=n; ++i) dsu[i]=i;
for (int i=1,s,t; i<=ecnt; ++i) if ((s=find(e[i].from))!=(t=find(e[i].to))) {
dsu[s]=t;
sta[++top]=e[i];
sum+=uni[++usiz]=e[i].val;
}
for (int i=1; i<=k; ++i) sum-=max(cnt[i]-1, 0);
if (sum!=0) {puts("0"); return ;}
// cout<<"top: "<<top<<endl;
sort(uni+1, uni+usiz+1);
usiz=unique(uni+1, uni+usiz+1)-uni-1;
for (int i=1; i<=usiz; ++i) {
for (int j=1; j<=n; ++j) dsu[j]=j, id[j]=0;
for (int j=1; j<=top; ++j) if (sta[j].val!=uni[i]) dsu[find(sta[j].from)]=find(sta[j].to);
tot=0;
for (int j=1; j<=n; ++j) if (!id[find(j)]) id[find(j)]=++tot;
mat.resize(tot-1, tot-1);
for (int j=1,s,t; j<=ecnt; ++j) if (e[j].val==uni[i]) {
s=id[find(e[j].from)]; t=id[find(e[j].to)];
mat[s][s]=(mat[s][s]+e[j].tim)%mod;
mat[s][t]=(mat[s][t]-e[j].tim)%mod;
}
ans=ans*mat.gauss()%mod;
}
cout<<(ans%mod+mod)%mod<<endl;
}
}
signed main()
{
freopen("treecnt.in", "r", stdin);
freopen("treecnt.out", "w", stdout);
scanf("%d%d", &n, &k);
for (int i=1; i<n; ++i) for (int j=i+1; j<=n; ++j) scanf("%lld", &e[i][j]), e[j][i]=e[i][j];
for (int i=1; i<=k; ++i) scanf("%s", st[i]+1);
// force::solve();
// if (n>50) puts("0");
// else task1::solve();
task::solve();
return 0;
}