题解 treecnt
Description
\(n\leq 500,k\leq 2000\)
Solution
令 \(w(i,j)\) 表示同时含有 \(i,j\) 两个点的限制的数量 , 在 \(i,j\) 之间连一条权值为 \(w(i,j)\) 的边 .
那么可以观察到 , 每个限制最多对生成树总边权贡献 \(|S|-1\) , 同时必须贡献 \(|S|-1\) 才是一个合法的生成树 , 那么对这个图做最大生成树计数即可 .
\(w(i,j)\) 可以使用 bitset 求出 .
时间复杂度 \(\displaystyle O(n^3+\frac{n^2k}{w})\)
code
#include<bits/stdc++.h>
typedef long long ll;
using namespace std;
int read()
{
int ret=0;bool f=0;char c=getchar();
while(c>'9'||c<'0')f|=(c=='-'),c=getchar();
while(c>='0'&&c<='9')ret=(ret<<3)+(ret<<1)+(c^48),c=getchar();
return f?-ret:ret;
}
const int maxn=505,maxm=2000,mod=998244353;
int qpow(int a,int b){int ret=1;for(;b;b>>=1,a=(ll)a*a%mod)if(b&1)ret=(ll)ret*a%mod;return ret;}
int n,k;
int val[maxn][maxn];
bitset<maxm>exi[maxn];
struct edge{int fr,to,val;}e[maxn*maxn];int cnt;
struct dsu
{
int fa[maxn];
void prework(){iota(fa+1,fa+n+1,1);}
int get(int x){return x==fa[x]?x:fa[x]=get(fa[x]);}
void merge(int x,int y){fa[get(x)]=get(y);}
bool check(int x,int y){return get(x)==get(y);}
}S;
struct matrix
{
int num[maxn][maxn];
void clear(){memset(num,0,sizeof num);}
int*operator[](const int &i){return num[i];}
int calc(int n)
{
int ret=1;
for(int i=1;i<=n;i++)
{
int pos=-1;
for(int j=i;j<=n;j++)if(num[j][i]){pos=j;break;}
if(pos==-1)return 0;
if(pos!=i){swap(num[pos],num[i]);ret=mod-ret;}
ret=(ll)ret*num[i][i]%mod;
int inv=qpow(num[i][i],mod-2);
for(int j=i;j<=n;j++)num[i][j]=(ll)num[i][j]*inv%mod;
for(int j=i+1;j<=n;j++)
if(num[j][i])
{
int tmp=num[j][i];
for(int k=i;k<=n;k++)
(num[j][k]+=mod-(ll)num[i][k]*tmp%mod)%=mod;
}
}
return ret;
}
}A;
int get()
{
int ret=1;
S.prework();
for(int i=1;i<=cnt;i++)
{
if(i!=1&&e[i].val==e[i-1].val)continue;
map<int,int>id;int num=0;A.clear();
for(int j=i;j<=cnt;j++)
{
if(e[j].val!=e[i].val)break;
if(S.check(e[j].fr,e[j].to))continue;
if(!id.count(S.get(e[j].fr)))id[S.get(e[j].fr)]=++num;
if(!id.count(S.get(e[j].to)))id[S.get(e[j].to)]=++num;
(A[id[S.get(e[j].fr)]][id[S.get(e[j].to)]]+=mod-val[e[j].fr][e[j].to])%=mod;
(A[id[S.get(e[j].to)]][id[S.get(e[j].fr)]]+=mod-val[e[j].fr][e[j].to])%=mod;
(A[id[S.get(e[j].to)]][id[S.get(e[j].to)]]+=val[e[j].fr][e[j].to])%=mod;
(A[id[S.get(e[j].fr)]][id[S.get(e[j].fr)]]+=val[e[j].fr][e[j].to])%=mod;
}
for(int j=i;j<=cnt;j++)
{
if(e[j].val!=e[i].val)break;
S.merge(e[j].fr,e[j].to);
}
static bool vis[maxn];
for(auto &j:id)
if(!vis[S.get(j.first)])
{
vis[S.get(j.first)]=1;
A[j.second][j.second]++;
}
for(auto &j:id)vis[S.get(j.first)]=0;
ret=(ll)ret*A.calc(num)%mod;
}
return ret;
}
int main()
{
freopen("treecnt.in","r",stdin);
freopen("treecnt.out","w",stdout);
n=read();k=read();
for(int i=1;i<n;i++)
for(int j=i+1;j<=n;j++)val[i][j]=val[j][i]=read();
int sum=0;
for(int i=0;i<k;i++)
{
string s;cin>>s;
bool flag=0;
for(int j=1;j<=n;j++)
if(s[j-1]=='1')exi[j][i]=1,sum++;
if(s.find('1')!=s.npos)sum--;
}
for(int i=1;i<=n;i++)
for(int j=i+1;j<=n;j++)
e[++cnt]={i,j,(int)((exi[i]&exi[j]).count())};
sort(e+1,e+cnt+1,[](edge &i,edge &j){return i.val>j.val;});
S.prework();
for(int i=1;i<=cnt;i++)
if(!S.check(e[i].fr,e[i].to))
{
sum-=e[i].val;
S.merge(e[i].fr,e[i].to);
}
if(sum){puts("0");return 0;}
printf("%d\n",get());
return 0;
}