ARC167C MST on Line++【题解】
题意:给出 \(n\) 个数 \(a_{1...n}\),选一个排列 \(p_{1...n}\)。对于 \(|i-j|\le k\) 的点 \(i,j\) 之间有一条权为 \(\max(a_{p_i},a_{p_j})\) 的边。对于所有排列 \(p\),求对应图的 \(\text{MST}\) 边权和的总和。
\(1\le k<n\le 5000\)
首先定义一条边的大小比较方式:若权值不同则按权值比;若权值相同按连接的编号较小的点的编号比;若都相同按编号较大的点的编号比。
考虑一条边 \((x,y)\) 的贡献(\(x<y\)),他的权值为 \(\max(a_{p_x},a_{p_y})\),他能被选入 \(\text{MST}\) 当且仅当不存在一条起点为 \(x\)、终点为 \(y\) 的路径,满足路径上的边全部小于他。
这个很好理解,因为如果路径存在大于他的边,完全可以用他来代替这条边。
思考什么情况下不存在这样的路径。
-
对于 \(u(u<x,y-u\le k)\),需满足 \(a_{p_u}>\max(a_{p_x},a_{p_y})\)。如果 \(a_{p_u}\le \max(a_{p_x},a_{p_y})\),那么存在路径 \(x\rightarrow u\rightarrow y\),不合法。
-
若 \(a_{p_x}>a_{p_y}\),对于 \(u(x<u<y)\),需满足 \(a_{p_u} \ge a_{p_x}\)。因为如果 \(a_{p_u}<a_{p_x}\),那么存在路径 \(x\rightarrow u\rightarrow y\),不合法。
显然,我们在此情况下无需找一条边数 \(>2\) 的路径,只需要满足上面两个条件即可。
枚举 \(val=\max(a_{p_x},a_{p_y})\),算出小于 \(val\)、等于 \(val\) 、大于 \(val\) 的数的个数分别有 \(cnt_1,cnt,cnt_2\) 个。
分类讨论。
- \(a_{p_x}=a_{p_y}=val\)
此时需要满足第一个条件。
枚举 \(x\) 前面有多少个这样的 \(u\) 需要满足条件,设为 \(i\),此刻对于 \(y\) 有两种情况:
-
\(x=i+1,\space \space x<y\le k\)
-
\(y\ge k,\space \space y-x+i=k\)
一共有 \((k-i-1)+(n-k)=n-i-1\) 种情况。
从 \(cnt\) 中选两个数 \(a_{p_x},a_{p_y}\) 排列有 \(A(cnt,2)\) 种,从 \(cnt_2\) 中选 \(i\) 个数排列有 \(A(cnt_2,i)\) 种,其余随便排有 \((n-i-2)!\) 种。
- \(a_{p_x}<a_{p_y}=val\)
此时需要满足第一个条件,统计方法其实跟 \(a_{p_x}=a_{p_y}\) 是类似的,只不过选择 \(a_{p_x},a_{p_y}\) 两个数的方法变为了 \(cnt\times cnt_1\)。
- \(a_{p_x}>a_{p_y}\)
此时两个条件都需要满足。
分两种情况:
- \(y>k\)
枚举 \(i\) 表示 \(x\) 前面需要满足条件 \(1\) 的数个数,不难发现 \(y-x=k-i\),即 \(x,y\) 中间有 \(k-i-1\) 个数。\(y\) 的取值有 \(n-k\) 种,选择 \(a_{p_x},a_{p_y}\) 的方案数为 \(cnt\times cnt_1\),从 \(cnt_2\) 个中选 \(i\) 个数排列有 \(A(cnt_2,i)\) 种,从 \((cnt_2+cnt-i-1)\) 个(去掉 \(a_{p_y}\) )中选 \((k-i-1)\) 个数排列有 \(A(cnt+cnt_2-i-1,k-i-1)\),其余随便 \((n-k-1)!\) 。
- \(y\le k\)
我们需要枚举两个:枚举 \(i\) 表示 \(x\) 前面需要满足条件 \(1\) 的数的个数,\(j\) 表示 \(x,y\) 中间的数个数。
那么 \(x=i+1,\space\space y=x+j+1=i+j+2\)。
类似上面的式子,有
化一下
设 \(g=cnt\times cnt_1\times A(cnt_2,i)\times(cnt+cnt_2-i-1)!\)。
设 \(w[u,v]=\sum_{j=0}^{\min(u,v)} \frac{(v-j)!}{(u-j)!}\),那么
\(w[,]\) 可以 \(O(n^2)\) 预处理,时间复杂度 \(O(n^2)\)。
#include<bits/stdc++.h>
#define ll long long
#define max(a,b) ((a)>(b)? (a):(b))
#define min(a,b) ((a)<(b)? (a):(b))
#define pir pair<ll,ll>
#define mkp make_pair
#define pb push_back
#define ad(a,b) (a=(a+b>=mod? a+b-mod:a+b))
using namespace std;
const ll maxn=5010,mod=998244353;
ll n,m,a[maxn],ans,fac[maxn],inv[maxn],w[maxn][maxn];
ll A(ll n,ll m)
{
if(n<m) return 0;
return fac[n]*inv[n-m]%mod;
}
void solve(ll val)
{
ll cnt=0, cnt1=0, cnt2=0;
ll res=0;
for(ll i=1;i<=n;i++) cnt+=(a[i]==val), cnt1+=(a[i]<val), cnt2+=(a[i]>val);
for(ll i=0;i<m;i++)
{
res=(res+(m-i-1 + n-m)*A(cnt,2)%mod*A(cnt2,i)%mod*fac[n-i-2])%mod;
}
for(ll i=0;i<m;i++)
{
res=(res+(m-i-1 + n-m)*cnt%mod*cnt1%mod*A(cnt2,i)%mod*fac[n-i-2])%mod;
}
for(ll i=0;i<m&&i<cnt2+cnt;i++)
{
res=(res+(n-m)*cnt%mod*cnt1%mod*A(cnt2,i)%mod*A(cnt2+cnt-i-1,m-1-i)%mod*fac[n-m-1])%mod;
ll rs=cnt*cnt1%mod*A(cnt2,i)%mod*fac[cnt2+cnt-i-1]%mod, tc=cnt2+cnt-i-1, nk=n-i-2;
ll r=(w[tc][nk]-(m-2-i<min(tc,nk)? w[tc-(m-2-i)-1][nk-(m-2-i)-1]:0)+mod)%mod;
res=(res+r*rs)%mod;
}
ans=(ans+res*val)%mod;
}
int main()
{
scanf("%lld%lld",&n,&m);
for(ll i=1;i<=n;i++) scanf("%lld",a+i);
fac[0]=inv[0]=fac[1]=inv[1]=1;
for(ll i=2;i<=n;i++)
{
fac[i]=fac[i-1]*i%mod;
inv[i]=(mod-mod/i)*inv[mod%i]%mod;
}
for(ll i=2;i<=n;i++) inv[i]=inv[i-1]*inv[i]%mod;
for(ll i=0;i<=n;i++)
{
for(ll j=0;j<=n;j++)
{
ll v=inv[i]*fac[j]%mod;
if(i&&j) w[i][j]=(w[i-1][j-1]+v)%mod;
else w[i][j]=v;
}
}
sort(a+1,a+1+n);
for(ll i=1;i<=n;i++)
{
if(a[i]!=a[i-1])
{
solve(a[i]);
}
}
printf("%lld",ans);
return 0;
}