题解 [AGC023E] Inversions
好神奇.jpg
先来考虑一个弱化版问题听说是经典问题:
给定一个长度为 \(n\) 的序列 \(A\),求满足 \(\forall i,P_i\le A_i\) 的 \(1\sim n\) 的排列数量
那么做法是将 \(a\) 排序后,方案数即为
\[tot=\prod\limits_{i=1}^n(a_i-i+1)
\]
即为考虑从小到大填到第 \(i\) 个限制时,已经用了 \(i-1\) 个数,所以这个位置可以填的数有 \(a_i-i+1\) 个
那么回到本题
考虑枚举两个位置,计算这两个位置形成逆序对数的方案数
令 \(i<j\),分 \(a_i\leqslant a_j\) 和 \(a_i>a_j\) 两种情况讨论
当 \(a_i\leqslant a_j\) 时:
因为形成了逆序对,所以要求 \(p_i>p_j\)
那么 \(i, j\) 两处的合法填法数为 \(\frac{(a_i-rk_i+1)(a_i-rk_i)}{2}\cdot\frac{tot}{(a_i-rk_i+1)(a_j-rk_j+1)}\)
再考虑 \(j\) 处填了一个较小的数,会对 \(rk_i<rk_k<rk_j\) 的 \(k\) 的方案数产生影响
那么乘上 \(\prod\limits_{rk_i<rk_k<rk_j}\frac{a_k-rk_k}{a_k-rk_k+1}\)
当 \(a_i>a_j\) 时:
发现 swap(i, j)
后就变成了上一种情况,那么用总方案数减去 \(p_i>p_j\) 的方案数即可
于是问题化为维护一个长这样的式子:
\[\frac{tot(a_i-rk_i)}{2(a_j-rk_j+1)}\prod\limits_{rk_i<rk_k<rk_j}\frac{a_k-rk_k}{a_k-rk_k+1}
\]
发现可以按 \(rk_i\) 递增插入(即 \(a_i\) 递增插入)
则问题变为整体乘,单点赋值和区间查询
复杂度 \(O(n\log n)\)
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 200010
#define fir first
#define sec second
#define ll long long
//#define int long long
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline ll read() {
ll ans=0, f=1; char c=getchar();
while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
return ans*f;
}
int n;
ll tot=1, ans;
pair<int, int> sta[N];
int a[N], rk[N], bit[N];
const ll mod=1e9+7, inv2=(mod+1)>>1;
inline ll qpow(ll a, ll b) {ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}
inline void add(int i, int dat) {for (; i<=n; i+=i&-i) bit[i]+=dat;}
inline int query(int l, int r) {
int ans=0; --l;
while (r>l) ans+=bit[r], r-=r&-r;
while (l>r) ans-=bit[l], l-=l&-l;
return ans;
}
#define tl(p) tl[p]
#define tr(p) tr[p]
#define pushup(p) val[p]=(val[p<<1]+val[p<<1|1])%mod
int tl[N<<2], tr[N<<2];
ll val[N<<2], tag[N<<2];
inline void spread(int p) {
if (tag[p]==1) return ;
val[p<<1]=val[p<<1]*tag[p]%mod; tag[p<<1]=tag[p<<1]*tag[p]%mod;
val[p<<1|1]=val[p<<1|1]*tag[p]%mod; tag[p<<1|1]=tag[p<<1|1]*tag[p]%mod;
tag[p]=1;
}
void build(int p, int l, int r) {
tl(p)=l; tr(p)=r; tag[p]=1;
if (l==r) return ;
int mid=(l+r)>>1;
build(p<<1, l, mid);
build(p<<1|1, mid+1, r);
pushup(p);
}
void upd(int p, int pos, ll dat) {
if (tl(p)==tr(p)) {val[p]=dat; return ;}
spread(p);
int mid=(tl(p)+tr(p))>>1;
if (pos<=mid) upd(p<<1, pos, dat);
else upd(p<<1|1, pos, dat);
pushup(p);
}
void upd(int p, int l, int r, ll dat) {
if (l<=tl(p)&&r>=tr(p)) {val[p]=val[p]*dat%mod; tag[p]=tag[p]*dat%mod; return ;}
spread(p);
int mid=(tl(p)+tr(p))>>1;
if (l<=mid) upd(p<<1, l, r, dat);
if (r>mid) upd(p<<1|1, l, r, dat);
pushup(p);
}
ll query(int p, int l, int r) {
if (l<=tl(p)&&r>=tr(p)) return val[p];
spread(p);
int mid=(tl(p)+tr(p))>>1; ll ans=0;
if (l<=mid) ans=(ans+query(p<<1, l, r))%mod;
if (r>mid) ans=(ans+query(p<<1|1, l, r))%mod;
return ans;
}
signed main()
{
n=read();
for (int i=1; i<=n; ++i) sta[i]={a[i]=read(), i};
sort(sta+1, sta+n+1);
for (int i=1; i<=n; ++i) rk[sta[i].sec]=i;
for (int i=1; i<=n; ++i) {
if (sta[i].fir-i+1<=0) {puts("0"); return 0;}
tot=tot*(sta[i].fir-i+1)%mod;
}
build(1, 1, n);
for (int i=1; i<=n; ++i) {
ans=(ans+tot*qpow(sta[i].fir-i+1, mod-2)%mod*inv2%mod*query(1, 1, sta[i].sec))%mod;
ans=(ans+query(sta[i].sec, n)*tot-tot*qpow(sta[i].fir-i+1, mod-2)%mod*inv2%mod*query(1, sta[i].sec, n))%mod;
// cout<<"i: "<<i<<' '<<(ans%mod+mod)%mod<<endl;
upd(1, 1, n, (sta[i].fir-i)*qpow(sta[i].fir-i+1, mod-2)%mod);
add(sta[i].sec, 1);
upd(1, sta[i].sec, (sta[i].fir-i)%mod);
}
printf("%lld\n", (ans%mod+mod)%mod);
return 0;
}