D. Count the Arrays
Your task is to calculate the number of arrays such that:
- each array contains $n$ elements;
- each element is an integer from $1$ to $m$;
- for each array, there is exactly one pair of equal elements;
- for each array $a$, there exists an index $i$ such that the array is strictly ascending before the $i$-th element and strictly descending after it (formally, it means that $a_j < a_{j+1}$ if $j < i$, and $a_j > a_{j+1}$ if $j \ge i$).
Input
The first line contains two integers $n$ and $m$ ($2 \le n \le m \le 2 \cdot 10^5$).
Output
Print one integer — the number of arrays that meet all of the aforementioned conditions, taken modulo $998244353$.
Examples
input
Copy
3 4
output
Copy
6
input
Copy
3 5
output
Copy
10
input
Copy
42 1337
output
Copy
806066790
input
Copy
100000 200000
output
Copy
707899035
Note
The arrays in the first example are:
- $[1, 2, 1]$;
- $[1, 3, 1]$;
- $[1, 4, 1]$;
- $[2, 3, 2]$;
- $[2, 4, 2]$;
- $[3, 4, 3]$.
$n$ 个元素中恰有一对相同,因此数组中共有 $n-1$ 个不同的数;从 $m$ 个数中选出这 $n-1$ 个数,方法数为 $\binom{m}{n-1}$。
从这 $n-1$ 个不同的数中选择一个数,使其在要构造的数组中出现两次;最大的数(峰顶)是唯一的,不能选它,所以方法数为 $n-2$。
除了重复出现的数必须一个在最大数的左边、一个在右边外,其他 $n-3$ 个数各自可以放在最大数的左边或右边,方法数为 $2^{n-3}$。由乘法原理,答案为 $\binom{m}{n-1}\cdot(n-2)\cdot 2^{n-3}$(当 $n=2$ 时因子 $n-2=0$,答案为 $0$)。
// Solution for "Count the Arrays": counts length-n arrays over [1, m] with
// exactly one repeated value that are strictly increasing up to a peak and
// strictly decreasing after it, modulo 998244353.
//
// Every preprocessor directive and every `//` comment below must stay on its
// own line: both extend to the end of the physical line, so joining lines
// silently discards the code that follows them.
#include <iostream>
#include <vector>
#include <algorithm>
#include <string>
#include <set>
#include <queue>
#include <map>
#include <sstream>
#include <cstdio>
#include <cstring>
#include <numeric>
#include <cmath>
#include <iomanip>
#include <deque>
#include <bitset>

#define ll long long
#define PII pair<int, int>
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define dec(i,a,b) for(int i=a;i>=b;i--)
#define pb push_back
#define mk make_pair
using namespace std;

// Six-neighbour offset tables kept from the contest template (unused here).
int dir1[6][2] = { { 0,1 } ,{ 0,-1 },{ 1,0 },{ -1,0 },{ 1,1 },{ -1,1 } };
int dir2[6][2] = { { 0,1 } ,{ 0,-1 },{ 1,0 },{ -1,0 },{ 1,-1 },{ -1,-1 } };
const long long INF = 0x7f7f7f7f7f7f7f7f;
const int inf = 0x3f3f3f3f;
const double pi = 3.14159265358979;
const int mod = 998244353;
const int N = 1000005;
//if(x<0 || x>=r || y<0 || y>=c)

// Reads one (optionally negative) integer from stdin via getchar().
// `c` is an int so that EOF is distinguishable from a character; on EOF the
// function stops instead of looping forever, returning what it has read so far.
inline ll read() {
    ll x = 0;
    bool f = true;
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') f = false;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = (x << 1) + (x << 3) + (c ^ 48);  // x * 10 + digit
        c = getchar();
    }
    return f ? x : -x;
}

// Modular arithmetic helpers; operands are expected to lie in [0, mod).
inline int add(int x, int y) { x += y; return x >= mod ? x -= mod : x; }
inline int sub(int x, int y) { x -= y; return x < 0 ? x += mod : x; }
inline int mul(int x, int y) { return 1ll * x * y % mod; }

// x^n mod `mod` by binary exponentiation; returns 1 when n <= 0.
inline int qpow(int x, ll n) {
    int r = 1;
    while (n > 0) {
        if (n & 1) r = 1ll * r * x % mod;
        n >>= 1;
        x = 1ll * x * x % mod;
    }
    return r;
}

// Modular inverse via Fermat's little theorem (998244353 is prime).
inline int Inv(int x) { return qpow(x, mod - 2); }

namespace Comb {
const int maxc = 2000000 + 5;
int f[maxc], inv[maxc], finv[maxc];  // factorials, inverses, inverse factorials

// Fills f / inv / finv for indices [0, limit). `limit` is clamped to
// [2, maxc]; the default keeps the old "fill everything" behaviour, while
// callers that know their bound can avoid touching the whole 2M-entry tables.
void init(int limit = maxc) {
    limit = std::max(2, std::min(limit, maxc));
    inv[1] = 1;
    // Linear-time inverses: inv[i] = -(mod / i) * inv[mod % i]  (mod p).
    for (int i = 2; i < limit; i++) inv[i] = (mod - mod / i) * 1ll * inv[mod % i] % mod;
    f[0] = finv[0] = 1;
    for (int i = 1; i < limit; i++) {
        f[i] = f[i - 1] * 1ll * i % mod;
        finv[i] = finv[i - 1] * 1ll * inv[i] % mod;
    }
}

// Binomial coefficient C(n, m) mod `mod`; 0 when m is outside [0, n].
// Requires init() to have covered index n.
int C(int n, int m) {
    if (m < 0 || m > n) return 0;
    return f[n] * 1ll * finv[n - m] % mod * finv[m] % mod;
}

// Number of solutions of x_1 + x_2 + ... + x_n = m with every x_i >= 0.
int S(int n, int m) {
    if (n == 0 && m == 0) return 1;
    return C(m + n - 1, n - 1);
}
}  // namespace Comb
using Comb::C;

// Reads n and m and prints C(m, n-1) * (n-2) * 2^(n-3) mod 998244353:
//   - choose the n-1 distinct values: C(m, n-1);
//   - choose which non-maximum value is duplicated: n-2;
//   - place each of the other n-3 values left or right of the peak: 2^(n-3).
int main() {
    int n, m;
    if (!(cin >> n >> m)) return 0;

    // With n == 2 the only candidate is [a, a], which has no strict peak.
    if (n < 3) {
        cout << 0 << '\n';
        return 0;
    }

    Comb::init(m + 1);  // C(m, n - 1) needs factorials up to index m only
    const int res = mul(mul(C(m, n - 1), n - 2), qpow(2, n - 3));
    cout << res << '\n';
    return 0;
}