Loading

ARC167C MST on Line++【题解】

题意:给出 \(n\) 个数 \(a_{1...n}\),选一个排列 \(p_{1...n}\)。对于 \(|i-j|\le k\) 的点 \(i,j\) 之间有一条权为 \(\max(a_{p_i},a_{p_j})\) 的边。对于所有排列 \(p\),求对应图的 \(\text{MST}\) 边权和的总和。

\(1\le k<n\le 5000\)


首先定义一条边的大小比较方式:若权值不同则按权值比;若权值相同按连接的编号较小的点的编号比;若都相同按编号较大的点的编号比。

考虑一条边 \((x,y)\) 的贡献(\(x<y\)),他的权值为 \(\max(a_{p_x},a_{p_y})\),他能被选入 \(\text{MST}\) 当且仅当不存在一条起点为 \(x\)、终点为 \(y\) 的路径,满足路径上的边全部小于他。

这个很好理解,因为如果路径存在大于他的边,完全可以用他来代替这条边。

思考什么情况下不存在这样的路径。

  • 对于 \(u(u<x,y-u\le k)\),需满足 \(a_{p_u}>\max(a_{p_x},a_{p_y})\)。如果 \(a_{p_u}\le \max(a_{p_x},a_{p_y})\),那么存在路径 \(x\rightarrow u\rightarrow y\),不合法。

  • \(a_{p_x}>a_{p_y}\),对于 \(u(x<u<y)\),需满足 \(a_{p_u} \ge a_{p_x}\)。因为如果 \(a_{p_u}<a_{p_x}\),那么存在路径 \(x\rightarrow u\rightarrow y\),不合法。

显然,我们在此情况下无需找一条边数 \(>2\) 的路径,只需要满足上面两个条件即可。

枚举 \(val=\max(a_{p_x},a_{p_y})\),算出小于 \(val\)、等于 \(val\) 、大于 \(val\) 的数的个数分别有 \(cnt_1,cnt,cnt_2\) 个。

分类讨论。

  • \(a_{p_x}=a_{p_y}=val\)

此时需要满足第一个条件。

枚举 \(x\) 前面有多少个这样的 \(u\) 需要满足条件,设为 \(i\),此刻对于 \(y\) 有两种情况:

  1. \(x=i+1,\space \space x<y\le k\)

  2. \(y\ge k,\space \space y-x+i=k\)

一共有 \((k-i-1)+(n-k)=n-i-1\) 种情况。

\(cnt\) 中选两个数 \(a_{p_x},a_{p_y}\) 排列有 \(A(cnt,2)\) 种,从 \(cnt_2\) 中选 \(i\) 个数排列有 \(A(cnt_2,i)\) 种,其余随便排有 \((n-i-2)!\) 种。

\[(n-i-1)\times A(cnt,2)\times A(cnt_2,i),\times (n-i-2)! \]

  • \(a_{p_x}<a_{p_y}=val\)

此时需要满足第一个条件,统计方法其实跟 \(a_{p_x}=a_{p_y}\) 是类似的,只不过选择 \(a_{p_x},a_{p_y}\) 两个数的方法变为了 \(cnt\times cnt_1\)

\[(n-i-1)\times cnt\times cnt_1\times A(cnt_2,i),\times (n-i-2)! \]

  • \(a_{p_x}>a_{p_y}\)

此时两个条件都需要满足。

分两种情况:

  • \(y>k\)

枚举 \(i\) 表示 \(x\) 前面需要满足条件 \(1\) 的数个数,不难发现 \(y-x=k-i\),即 \(x,y\) 中间有 \(k-i-1\) 个数。\(y\) 的取值有 \(n-k\) 种,选择 \(a_{p_x},a_{p_y}\) 的方案数为 \(cnt\times cnt_1\),从 \(cnt_2\) 个中选 \(i\) 个数排列有 \(A(cnt_2,i)\) 种,从 \((cnt_2+cnt-i-1)\) 个(去掉 \(a_{p_y}\) )中选 \((k-i-1)\) 个数排列有 \(A(cnt+cnt_2-i-1,k-i-1)\),其余随便 \((n-k-1)!\)

\[(n-k)\times cnt\times cnt_1\times A(cnt_2,i)\times A(cnt+cnt_2-i-1,k-i-1)\times (n-k-1)! \]

  • \(y\le k\)

我们需要枚举两个:枚举 \(i\) 表示 \(x\) 前面需要满足条件 \(1\) 的数的个数,\(j\) 表示 \(x,y\) 中间的数个数。

那么 \(x=i+1,\space\space y=x+j+1=i+j+2\)

类似上面的式子,有

\[\sum_{j=0}^{k-i-2} cnt\times cnt_2\times A(cnt_2,i)\times A(cnt+cnt_2-i-1,j)\times (n-i-j-2)! \]

化一下

\[\begin{aligned} \text{原式} &= cnt\times cnt_1\times A(cnt_2,i)\times \sum_{j=0}^{k-i-2}A(cnt+cnt_2-i-1,j)\times (n-i-j-2)! \\ &= cnt\times cnt_1 \times A(cnt_2,i)\times \sum_{j=0}^{k-i-2} \frac {(cnt+cnt_2-i-1)!}{(cnt+cnt_2-i-j-1)!} \times (n-i-j-2)! \\ &= cnt\times cnt_1\times A(cnt_2,i)\times(cnt+cnt_2-i-1)!\times \sum_{j=0}^{k-i-2} \frac{(n-i-j-2)!}{(cnt+cnt_2-i-j-1)!} \end{aligned} \]

\(g=cnt\times cnt_1\times A(cnt_2,i)\times(cnt+cnt_2-i-1)!\)

\(w[u,v]=\sum_{j=0}^{\min(u,v)} \frac{(v-j)!}{(u-j)!}\),那么

\[\begin{aligned} \text{原式} &= g\times (w[cnt+cnt_2-i-1,n-i-2]-w[(cnt+cnt_2-i-1)-(k-i-2)-1,(n-i-2)-(k-i-2)-1]) \end{aligned} \]

\(w[,]\) 可以 \(O(n^2)\) 预处理,时间复杂度 \(O(n^2)\)

#include<bits/stdc++.h>
#define ll long long
#define max(a,b) ((a)>(b)? (a):(b))
#define min(a,b) ((a)<(b)? (a):(b))
#define pir pair<ll,ll>
#define mkp make_pair
#define pb push_back 
#define ad(a,b) (a=(a+b>=mod? a+b-mod:a+b))
using namespace std;
const ll maxn=5010,mod=998244353;
ll n,m,a[maxn],ans,fac[maxn],inv[maxn],w[maxn][maxn];
ll A(ll n,ll m)
{
	if(n<m) return 0;
	return fac[n]*inv[n-m]%mod;
}
void solve(ll val)
{
	ll cnt=0, cnt1=0, cnt2=0;
	ll res=0;
	for(ll i=1;i<=n;i++) cnt+=(a[i]==val), cnt1+=(a[i]<val), cnt2+=(a[i]>val);
	for(ll i=0;i<m;i++)
	{
		res=(res+(m-i-1 + n-m)*A(cnt,2)%mod*A(cnt2,i)%mod*fac[n-i-2])%mod;
	}
	
	for(ll i=0;i<m;i++)
	{
		res=(res+(m-i-1 + n-m)*cnt%mod*cnt1%mod*A(cnt2,i)%mod*fac[n-i-2])%mod;
	}
	
	for(ll i=0;i<m&&i<cnt2+cnt;i++)
	{
		res=(res+(n-m)*cnt%mod*cnt1%mod*A(cnt2,i)%mod*A(cnt2+cnt-i-1,m-1-i)%mod*fac[n-m-1])%mod;
		ll rs=cnt*cnt1%mod*A(cnt2,i)%mod*fac[cnt2+cnt-i-1]%mod, tc=cnt2+cnt-i-1, nk=n-i-2;
		ll r=(w[tc][nk]-(m-2-i<min(tc,nk)? w[tc-(m-2-i)-1][nk-(m-2-i)-1]:0)+mod)%mod;
		res=(res+r*rs)%mod;
	}
	ans=(ans+res*val)%mod;
}
int main()
{
	scanf("%lld%lld",&n,&m);
	for(ll i=1;i<=n;i++) scanf("%lld",a+i);
	fac[0]=inv[0]=fac[1]=inv[1]=1;
	for(ll i=2;i<=n;i++)
	{
		fac[i]=fac[i-1]*i%mod;
		inv[i]=(mod-mod/i)*inv[mod%i]%mod;
	}
	for(ll i=2;i<=n;i++) inv[i]=inv[i-1]*inv[i]%mod;
	for(ll i=0;i<=n;i++)
	{
		for(ll j=0;j<=n;j++)
		{
			ll v=inv[i]*fac[j]%mod;
			if(i&&j) w[i][j]=(w[i-1][j-1]+v)%mod;
			else w[i][j]=v;
		}
	}
	sort(a+1,a+1+n);
	for(ll i=1;i<=n;i++)
	{
		if(a[i]!=a[i-1])
		{
			solve(a[i]);
		}
	} 
	printf("%lld",ans);
	return 0;
}
posted @ 2023-10-16 15:51  Lgx_Q  阅读(82)  评论(0编辑  收藏  举报