POI XV: Permutation


A multiset is a mathematical object similar to a set, except that each element may occur in it more than once. Just as with any set, the elements of a multiset can be ordered in many ways; we call each such ordering a permutation of the multiset. For example, among the permutations of the multiset \(\{1,1,2,3,3,3,7,8\}\) are \((2,3,1,3,3,7,1,8)\) and \((8,7,3,3,3,2,1,1)\).
We say that one permutation of a given multiset is smaller (in lexicographic order) than another if, at the first position where the two differ, the first permutation has a smaller element than the other. All permutations of a given multiset can be numbered (starting from one) in increasing order.
Write a programme that

  • reads the description of a permutation of a multiset and a positive integer \(m\) from the standard input,
  • determines the remainder of the rank of that permutation in the lexicographic ordering modulo \(m\),
  • writes out the result to the standard output.


The first line of the standard input holds two integers \(n\) and \(m\) \((1 \le n \le 3 \times 10^5,\ 2 \le m \le 10^9)\), separated by a single space. These denote, respectively, the cardinality of the multiset and the number \(m\). The second line of the standard input contains \(n\) positive integers \(a_i\) \((1 \le a_i \le 3 \times 10^5)\), separated by single spaces, denoting the successive elements of the multiset permutation.


The first and only line of the standard output should contain one integer: the remainder modulo \(m\) of the rank of the input permutation in the lexicographic ordering.

Sample Input

4 1000
2 1 10 2

Sample Output

5


对于由 \(N\) 个互不相同的元素构成的排列,比它小的排列个数为\[X = \sum_{i = 1}^{N}(r_i-1)(N-i)! \]其中 \(r_i\) 是 \(a_i\) 在第 \(i\) 到第 \(N\) 位的元素中从小到大的排名,所求排名为 \(X+1\)。



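存在重复元素时,设 \(c_k\) 为第 \(i\) 到第 \(N\) 位中值 \(k\) 的出现次数(随 \(i\) 变化),\(f_{i,j}\) 为前 \(i-1\) 位与输入相同、第 \(i\) 位为 \(j\) 的排列个数,则\[X = \sum_{i = 1}^N \sum_{j < a_i}f_{i,j} \]第 \(i\) 位放 \(j\) 之后,剩下的 \(N-i\) 个元素任意排列,故当 \(c_j \ge 1\) 时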


\[f_{i,j} = \frac{(N-i)!}{\left(\prod_{k = 1}^{j-1}c_k!\right)(c_j-1)!\left(\prod_{k = j+1}^{3 \times 10^5}c_k!\right)} \]当 \(c_j = 0\) 时 \(f_{i,j} = 0\)。

由于 \((c_j-1)! = c_j!/c_j\),首先有\[f_{i,j} = c_j\frac{(N-i)!}{\prod_{k = 1}^{3 \times 10^5}c_k!}\]这一形式在 \(c_j = 0\) 时同样成立,于是

\[\sum_{j < a_i}f_{i,j} = \frac{(N-i)!}{\prod_{k = 1}^{3 \times 10^5}c_k!}\sum_{k = 1}^{a_i-1} c_k \]其中 \(\sum_{k<a_i}c_k\) 用树状数组维护;从第 \(i\) 位转移到第 \(i+1\) 位时,前面的分式乘以 \(c_{a_i}\) 再除以 \(N-i\) 即可。



由于 \(M\) 不一定是质数,上式中的除法不能直接用逆元完成。将 \(M\) 分解质因数:\[M = \prod_{i = 1}^qp_i^{d_i} \]

然后我们只要能够计算出\(ans\)在模\(M_i = p_i^{d_i}\)的值,再通过中国剩余定理就可以计算答案了。那么这个怎么求呢?
我们可以将每个数 \(x\) 用一个二元组 \((s,t)\) 来表示,满足 \(x = s \times p_i^t\) 且 \(\gcd(s,p_i) = 1\),其中 \(s\) 只需保存模 \(M_i\) 的值:

  • \((s,t) \times (u,v) = (s \times u,t+v)\)
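  • \((s,t) / (u,v) = (s \times u^{-1},t-v)\),其中 \(u^{-1}\) 是 \(u\) 模 \(M_i\) 的逆元(由于 \(\gcd(u,p_i) = 1\),它一定存在)
最后 \(x \bmod M_i = s \times p_i^t \bmod M_i\)。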


#include <cstdio>
#include <cstring>

using namespace std;

typedef long long ll;
// Lowest set bit of a: the step size used when walking the Fenwick tree.
#define lowbit(a) (a&(-a))
// Array capacity; n and every value a_i are at most 3*10^5.
#define maxn (300010)
// mod : prime-power modulus Mi[id] currently used by node arithmetic (set in work()).
// ans : final answer printed by main.   N, M : input n and m.
// tot : number of prime-power factors of M found by Div().
// A   : the input permutation, 1-based.   aim : not used anywhere in the visible code.
// Mi, Pi : k-th prime-power factor of M and its prime.   res : residue of (rank - 1) modulo Mi[k].
// tree : Fenwick tree over values holding the counts of not-yet-placed elements.
// num  : multiplicity of each value in the input.   tnum : working copy of num consumed by work().
int mod,ans,N,M,tot,A[maxn],aim[maxn],Mi[maxn],Pi[maxn],res[maxn],tree[maxn],num[maxn],tnum[maxn];

// Solves a*x ≡ c (mod b) by reducing it to (b mod a)*y ≡ -c (mod a) and then
// back-substituting x = (y*b + c) / a. For non-negative arguments the returned x is
// non-negative; -1 is returned when the reduction reaches a == 0 without a solution.
// With c == 1 this yields an inverse of a modulo b.
inline ll exgcd(ll a,ll b,ll c)
{
	if (a == 0) return -1;
	if (c % a == 0) return c / a;
	const ll y = exgcd(b % a, a, ((-c % a) + a) % a);
	return y == -1 ? -1 : (y * b + c) / a;
}

// Binary exponentiation: returns a^b mod c by square-and-multiply.
// The base is not reduced before the first squaring, so callers pass a < c.
inline ll qsm(ll a,int b,int c)
{
	ll result = 1;
	ll base = a;
	while (b)
	{
		if (b & 1) result = result * base % c;
		base = base * base % c;
		b >>= 1;
	}
	return result;
}

// A nonzero integer x written as a * p^b with gcd(a, p) = 1, where the unit part a is
// kept modulo the global `mod` (the current prime power Mi[id]). Multiplication and
// division act component-wise, so dividing out factors of p never needs an inverse of p.
struct node
{
	// Unit part (mod `mod`) and exponent of p. Zero-initialized so that nodes built by
	// the default constructor (e.g. `ret` in the operators) never hold indeterminate values.
	int a = 0, b = 0;
	// Decomposes x into (x / p^b mod `mod`, b). With p == 0 or x == 0 the node stays (0, 0):
	// zero has no such decomposition, and the while loop below would never terminate on it.
	inline node(int x = 0,int p = 0)
	{
		if (!p || !x) return;
		while (!(x % p)) ++b, x /= p;
		a = x % mod;
	}
	friend inline node operator * (const node &x,const node &y)
	{
		node ret;
		ret.a = (ll)x.a*(ll)y.a%mod;
		ret.b = x.b+y.b;
		return ret;
	}
	// Requires gcd(y.a, mod) == 1, which holds for nodes built from nonzero integers
	// with the same prime as `mod`, so exgcd finds the inverse of y.a.
	friend inline node operator / (const node &x,const node &y)
	{
		node ret;
		int inv = exgcd(y.a,mod,1)%mod;
		ret.a = (ll)x.a*(ll)inv%mod;
		ret.b = x.b-y.b;
		return ret;
	}
	// Value a * p^b modulo `mod`; b must be non-negative here.
	inline int tran(int p) { return (ll)a*qsm(p,b,mod)%mod; }
};

// Fenwick-tree point update: adds b to the count stored for value a (values 1..300000).
inline void ins(int a,int b)
{
	while (a <= 300000)
	{
		tree[a] += b;
		a += lowbit(a);
	}
}
// Fenwick-tree prefix query: total count stored for values 1..a (0 when a == 0).
inline int calc(int a)
{
	int total = 0;
	while (a != 0)
	{
		total += tree[a];
		a -= lowbit(a);
	}
	return total;
}

inline void Div(int key)
	for (int i = 2;i*i <= key;++i)
		if (key % i == 0)
			Mi[++tot] = 1; Pi[tot] = i;
			while (key % i == 0) Mi[tot] *= i,key /= i;
	if (key > 1) Mi[++tot] = key,Pi[tot] = key;

inline void init()
	memset(tree,0,sizeof(tree)); memcpy(tnum,num,sizeof(num));
	for (int i = 1;i <= 300000;++i) if (num[i]) ins(i,num[i]);

inline void work(int id)
	init(); mod = Mi[id]; node now(1,Pi[id]);
	for (int i = 1;i < N;++i)
		node tmp(i,Pi[id]);
		now = now*tmp;
	for (int i = 1;i <= 300000;++i)
		for (int j = 2;j <= num[i];++j) { node tmp(j,Pi[id]); now = now/tmp; }
	for (int i = 1,sum;i <= N;++i)
		if (sum = calc(A[i]-1)) res[id] += (now*node(sum,Pi[id])).tran(Pi[id]);
		if (res[id] >= mod) res[id] -= mod; ins(A[i],-1);
		if (i < N)
			node tmp1(N-i,Pi[id]),tmp2(tnum[A[i]]--,Pi[id]);
			now = now*tmp2/tmp1;

inline int crt()
	int ret = 0;
	for (int i = 1;i <= tot;++i)
		int tm = M/Mi[i],inv = exgcd(tm%Mi[i],Mi[i],1)%Mi[i];
		ret += ((ll)res[i]*(ll)inv%M*(ll)tm)%M;
		if (ret >= M) ret -= M;
	return ret;

int main()
	// freopen("permutation.in","r",stdin);
	// freopen("permutation.out","w",stdout);
	scanf("%d %d",&N,&M);
	for (int i = 1;i <= N;++i) scanf("%d",A+i);
	for (int i = 1;i <= N;++i) ++num[A[i]];
	for (int i = 1;i <= tot;++i) work(i);
	ans = crt(); if (++ans >= M) ans -= M;
	// fclose(stdin); fclose(stdout);
	return 0;

