2021牛客寒假算法基础集训营5 B. 比武招亲(上)(排列组合)
链接:https://ac.nowcoder.com/acm/contest/9985/B
来源:牛客网
题目描述
众所周知,天姐姐只喜欢天下最聪明的人,为了找到这样的人,她决定比武招亲!
只见天姐姐在榜上留下了这样一道问题,谁做出来了就可以俘获她的芳心!
爱慕天姐姐已久的泽鸽鸽问询赶来,只见榜上写着:
给定 n,m,定义一种序列,构造方法如下:
1.1. 在 [1,n] 中任意选择 m 次,得到了 m 个整数(显然数字可能相同);
2.2. 将选出的 m 个数字排序之后得到一个序列 {a1,a2,...,am}。
定义一个序列的贡献为 max{a1,a2,...,am}−min{a1,a2,...,am},求所有本质不同的序列的贡献和。
为了防止结果过大,将答案为 998244353 取模后输出。
(对于两个序列长度为m的序列 A、B,若 ∃i∈[1,m],Ai≠Bi,则序列 A、B 本质不同)
泽鸽鸽心有余而力不足,而你作为他最好的基友决定帮助泽鸽鸽俘获美人心!
现在,这个重任就交给你啦!
输入描述:
一行输入两个正整数 n,m
【数据规模与约定】
1 <= n, m <= 5*10^5
输出描述:
一行一个整数,为答案对 998244353 取模后的结果。
示例1
输入
复制
3 2
输出
复制
4
说明
本质不同的序列有如下几种:1 1、2 2、3 3、1 2、1 3、2 3,贡献为 0+0+0+1+2+1=4。
根据问题,计算贡献和大致是dp或者排列组合计数来解。仔细一看发现和DP没啥关系,于是考虑组合计数。
首先根据题意,知道一个序列的最大值和最小值后这个序列的贡献就可以被唯一确定出来,即最大值减最小值。不妨设某个长度为m的序列最大值为mx,最小值为mn,那么这个序列排序后究竟长什么样呢?是mn, .., .., .., ....... .., mx。肯定有一个mx和一个mn是确定的,关键在于剩下m - 2个数有多少种选法。我们用选法数乘以(mx - mn)就能得到最大值为mx最小值为mn的所有序列的总贡献了。而易知剩下m - 2个数肯定都在[mn, mx]内,现在就转化为一个高中常见的排列组合问题:x个相同的小球放入y个盒子,允许有空盒,问一共有多少种放法。答案是C(x + y - 1, y - 1)。对于这个问题而言,相当于m - 2个数分配到[mn, mx]的区间内,允许有数不被覆盖,问有多少种分配方法。即C(m - 2 + mx - mn, mx - mn - 1)。
组合数取模有模版,根据卢卡斯定理,预处理出逆元来就可以计算。那么问题来了:最大值和最小值怎么确定。直接二重循环枚举肯定t得妈妈都不认识,但其实没有必要二重循环。注意到mn = 1, mx = 5和mn = 2, mx = 6的情况其实答案相同,因此我们直接枚举mx - mn(mx - mn其实也是贡献),再乘上这样的区间个数计算即可。
坑点:预处理逆元的时候要开2倍n的大小!这个要根据组合数计算的范围来开!
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>
#include <queue>
#include <map>
#include <set>
#define p 998244353
#define ll long long
using namespace std;
long long n, m;
#define LL long long
const int maxn=1000005;
void extend_gcd(LL a,LL b,LL &x,LL &y){
if(b==0){
x=1,y=0;
return;
}
extend_gcd(b,a%b,y,x);
y-=a/b*x;
}
LL inv[maxn+10];
LL f[maxn+10];
void init(){//阶乘及其逆元打表
f[0]=1;
for(int i=1;i<=maxn;i++){
f[i]=f[i-1]*i%p;
}
LL x,y;
extend_gcd(f[maxn],p,x,y);//先求出f[N]的逆元,再循环求出f[1~N-1]的逆元
inv[maxn]=(x%p+p)%p;
for(int i=maxn-1;i>=1;i--){
inv[i]=inv[i+1]*(i+1)%p;
}
}
LL C(LL n,LL m){
if(n==m||m==0)return 1;
return (f[n]*inv[m]%p*inv[n-m]%p)%p;
}
int main()
{
//freopen("data.txt", "r", stdin);
init();
cin >> n >> m;
long long ans = 0;
for(long long i = 0; i <= n - 1; i++)
{
if(m + i - 2 >= i && m + i - 2 >= 0)
ans = (ans + i * (n - i) % p * C(m + i - 2, i) % p) % p;
//i即mx - mn为贡献,n - i为mx - mn为i的区间的个数,C为贡献
//处理逆元一定要开够n + m的范围
}
cout << ans;
return 0;
}