[省选集训2022] 模拟赛19

矩阵树

题目描述

给定一个 \(n\) 个点的无向完全图,其中边 \((i,j)\) 的个数是 \(a(i,j)\);有 \(k\) 个要求,第 \(i\) 个要求是点集 \(S\) 的导出子图要连通,问满足条件的生成树个数,答案对 \(998244353\) 取模。

\(n\leq 500,k\leq 2000\)

解法

如果没有限制就是裸的矩阵树定理,这其实我们往矩阵树的方向思考。

首先观察 \(S\) 的导出子图连通有什么性质,我们可以将限制转化到边上,那么就相当于有恰好 \(|S|-1\) 条边(忽略 \(|S|=0\) 的情况),满足其的两个端点都在 \(S\) 中。并且有一个关键的 \(\tt observation\):无论合法还是不合法的情况,这样的边数最多有 \(|S|-1\) 个。

那么说明本题的判据跟最值有一定关联了,那么对于全部的 \(k\) 个限制,设 \(w(i,j)\) 表示边 \((i,j)\) 的两个端点都出现在了多少个 \(S\) 中,那么只有最大生成树才可能成为答案

可以用 \(\tt bitset\)\(O(\frac{n^2k}{w})\) 的时间求出每条边的边权,最大生成树计数是经典问题。由于每种边权的数量固定,对于每种边权的每个连通块,我们单独跑矩阵树定理,限制好矩阵大小时间复杂度就是 \(O(n^3)\) 的。

#include <cstdio>
#include <bitset>
#include <iostream>
#include <algorithm>
using namespace std;
const int M = 505;
const int N = 2005;
const int MOD = 998244353;
#define int long long
int read()
{
	int x=0,f=1;char c;
	while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
	while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
	return x*f;
}
int n,k,sum,ans,a[M][M],fa[M],fn[M],vis[M],id[M];
bitset<N> g[M];char s[N];
struct edge{int u,v,c;};vector<edge> e[N],G[M];
struct matrix
{
	int n,a[M][M];
	void clear()
	{
		for(int i=1;i<=n;i++)
			for(int j=1;j<=n;j++)
				a[i][j]=0;
		n=0;
	}
	void add(int u,int v,int c)
	{
        a[u][u]=(a[u][u]+c)%MOD;
		a[v][v]=(a[v][v]+c)%MOD;
        a[u][v]=(a[u][v]+MOD-c)%MOD;
		a[v][u]=(a[v][u]+MOD-c)%MOD;
	}
	int qkpow(int a,int b)
	{
		int r=1;
		while(b>0)
		{
			if(b&1) r=r*a%MOD;
			a=a*a%MOD;
			b>>=1;
		}
		return r;
	}
	int gauss()
	{
		int ans=1;
		for(int i=2;i<=n;i++)
	    {
	        for(int j=i+1;j<=n;j++)
	            if(!a[i][i] && a[j][i])
	            {
	                ans=MOD-ans;
	                swap(a[i],a[j]);
	                break;
	            }
	        ans=ans*a[i][i]%MOD;
	        int inv=qkpow(a[i][i],MOD-2);
	        for(int j=i+1;j<=n;j++)
	        {
	            int tmp=a[j][i]*inv%MOD;
	            for(int k=i;k<=n;k++)
	                a[j][k]=(a[j][k]-a[i][k]*tmp
					%MOD+MOD)%MOD;
	        }
	    }
		return ans;
	}
}z;
int find(int x)
{
	if(x==fa[x]) return x;
	return fa[x]=find(fa[x]);
}
int zxy(int x)
{
	if(x==fn[x]) return x;
	return fn[x]=zxy(fn[x]);
}
void kruskal()
{
	ans=1;
	for(int i=1;i<=n;i++) fa[i]=fn[i]=i;
	for(int i=k;i>=0;i--) if(e[i].size())
	{
		for(edge x:e[i])
			fn[zxy(x.u)]=fn[zxy(x.v)];
		for(edge x:e[i])
		{
			int u=find(x.u),v=find(x.v);
			if(u^v) G[zxy(u)].push_back(x);
		}
		for(int R=1;R<=n;R++) if(G[R].size())
		{
			for(int j=1;j<=n;j++) vis[j]=id[j]=0;
			int m=0;
			for(edge x:G[R])
			{
				int u=find(x.u),v=find(x.v);
				if(!vis[u]) vis[u]=1,id[u]=++m;
				if(!vis[v]) vis[v]=1,id[v]=++m;
				z.add(id[u],id[v],x.c);
			}
			z.n=m;
			ans=ans*z.gauss()%MOD;
			z.clear();
			G[R].clear();
		}
		for(edge x:e[i])
		{
			int u=find(x.u),v=find(x.v);
			if(u^v) sum-=i,fa[u]=v;
		}
	}
	printf("%lld\n",(sum==0)?ans:0); 
}
signed main()
{
	freopen("treecnt.in","r",stdin);
	freopen("treecnt.out","w",stdout);
	n=read();k=read();
	for(int i=1;i<=n;i++)
		for(int j=i+1;j<=n;j++)
			a[i][j]=read();
	for(int i=1;i<=k;i++)
	{
		scanf("%s",s+1);
		int fl=0;
		for(int j=1;j<=n;j++) if(s[j]=='1')
			g[j][i]=1,fl=1,sum++;
		sum-=fl;
	}
	for(int i=1;i<=n;i++)
		for(int j=i+1;j<=n;j++)
		{
			int w=(g[i]&g[j]).count();
			e[w].push_back({i,j,a[i][j]});
		}
	kruskal();
}
posted @ 2022-04-01 19:30  C202044zxy  阅读(160)  评论(2编辑  收藏  举报