Loading

Educational Codeforces Round 83 (Rated for Div. 2) D. Count the Arrays(排列组合)

Your task is to calculate the number of arrays such that:

  • each array contains 𝑛n elements;
  • each element is an integer from 11 to 𝑚m;
  • for each array, there is exactly one pair of equal elements;
  • for each array 𝑎a, there exists an index 𝑖i such that the array is strictly ascending before the 𝑖i-th element and strictly descendingafter it (formally, it means that 𝑎𝑗<𝑎𝑗+1aj<aj+1, if 𝑗<𝑖j<i, and 𝑎𝑗>𝑎𝑗+1aj>aj+1, if 𝑗≥𝑖j≥i).

Input

The first line contains two integers 𝑛n and 𝑚m (2≤𝑛≤𝑚≤2⋅1052≤n≤m≤2⋅105).

Output

Print one integer — the number of arrays that meet all of the aforementioned conditions, taken modulo 998244353998244353.

Examples

input

Copy

3 4

output

Copy

6

input

Copy

3 5

output

Copy

10

input

Copy

42 1337

output

Copy

806066790

input

Copy

100000 200000

output

Copy

707899035

思路歪了。。正解其实很简单。首先要选出来n - 1个不重复的数,然后需要指定其中的哪个数有一对,根据乘法原理这样的方案是\(C_m^{n - 1}\times (n - 2)\),注意这里乘的是n - 2,因为最大值显然不能重复。最后还需要给每个元素分配到左边或者右边,这需要乘\(2^{n - 3}\)(只需要给n - 3个元素分位置,因为n个位置里一个是拐点两个是重复的)。注意给每个元素分配好位置后得到的序列是唯一的(因为有第四条规则的限制),所以只需要分到左边或者右边即可。

以及n = 2的时候无解,需要特判。

#include <bits/stdc++.h>
#define mod 998244353
#define ll long long 
#define LL long long
#define p 998244353
using namespace std;
ll n, m;
const long long maxn = 200005;
void extend_gcd(LL a,LL b,LL &x,LL &y){
    if(b==0){
        x=1,y=0;
        return;
    }
    extend_gcd(b,a%b,y,x);
    y-=a/b*x;
}

ll inv[maxn + 10];
ll f[maxn + 10];
void init(){//阶乘及其逆元打表
    f[0]=1;
    for(int i=1;i<=maxn;i++){
        f[i]=f[i-1]*i%p;
    }

    LL x,y;
    extend_gcd(f[maxn],p,x,y);//先求出f[N]的逆元,再循环求出f[1~N-1]的逆元
    inv[maxn]=(x%p+p)%p;
    for(int i=maxn-1;i>=1;i--){
        inv[i]=inv[i+1]*(i+1)%p;
    }
}

LL C(LL n,LL m){
	if(m < 0) return 0;
    if(n==m||m==0)return 1;
    return (f[n]*inv[m]%p*inv[n-m]%p)%p;
}
ll fpow(ll a, ll b) {
	if(b < 0) return 0;
	ll ans = 1;
	for(; b; b >>= 1) {
		if(b & 1) ans = ans * a % mod;
		a = a * a % mod;
	}
	return ans;
}
int main() {
	init();
	cin >> n >> m;
	cout << C(m, n - 1) * (n - 2) % mod * fpow(2, n - 3) % mod;
}
posted @ 2021-10-22 22:02  脂环  阅读(35)  评论(0编辑  收藏  举报