Codeforces Round #751 (Div. 2) E. Optimal Insertion题解
E. Optimal Insertion
这道逆序对的题真的搞的是头大....
首先根据直观上看,b在最优解的序列中一定递增是最优。这个确实也可以证出来。因为假如在一种情况中,存在\(b_i>b_j\),且\(b_i\)比\(b_j\)更靠前的情况,那我们交换这两个数,首先减少了这一对逆序对,然后本来在这两个数之间的数原本存在的逆序对可能减少,但绝不会增加逆序对的个数,所以b递增的插入一定是最优的。我们考虑如果我们只插一个b的话,我们应该怎么插,显然的是找到一个位置,使得这个位置之前比b大的个数加上这个位置之后比b小的数的个数和最少的位置。我们先将a和b排序。(排序大法好!)考虑a中的每个位置,实际上a中的每个数都可能对这个位置造成贡献(这取决于插入到这个位置的b的值),考虑一个位置,这个位置贡献的逆序对的个数就是这个位置前比他大的a的个数,和这个位置后比他小的a的个数。我们尝试维护这n个位置的n个贡献。注意每个位置中的每个数的贡献只能是0和1.我们先假设所有的a都比b大,然后向右枚举b,尝试给b找位置,这个时候把b小的a枚举。考虑a的枚举给不同的位置带来那些改变,首先,我们比如说当前枚举到\(b_i\),例如当前有一个比他小的\(a_j\),它所在的位置在k,那么考虑位置小于等于k的那些数组,他们这些数组,在k这个元素的贡献应该改为1,对于位置大于k的那些数组,首先要修改他们在k这个元素的贡献,将1改为0,因为每个位置对应的整个和就是这个位置的贡献,所以我们可以考虑用线段树直接维护这些位置的贡献和。即线段树的每个子节点都表示当b放在这个位置时,造成的贡献,我们每次查询最小值即可。根据\(a_j\)直接对(1,k)贡献加1,对(k+1,n+1)的贡献减1.这里容易发现,对于每个递增的b而言,每次前半部分的位置的值总是在增加,后半部分的位置的贡献总是在减少,所以也可以发现,我们只考虑插入每一个b时的最优位置也是递增的。符合我们上面的第一条论述。
#include<bits/stdc++.h>
#define ll long long
#define ls p<<1
#define rs p<<1|1
using namespace std;
const int N=1e6+10;
int T,n,m,a[N],b[N],c[N],d[N],p[N],num;
ll ans;
struct Tree
{
int l,r,dat,tag;
#define l(p) t[p].l
#define r(p) t[p].r
#define dat(p) t[p].dat
#define tag(p) t[p].tag
}t[N*25];
inline bool cmp(int x,int y){return a[x]<a[y];}
inline void build(int p,int l,int r)
{
l(p)=l;r(p)=r;
if(l==r)
{
tag(p)=0;
dat(p)=l-1;
return;
}
int mid=l+r>>1;
build(ls,l,mid);
build(rs,mid+1,r);
dat(p)=min(dat(ls),dat(rs));
tag(p)=0;
}
inline void push(int p)
{
if(tag(p)!=0)
{
dat(ls)+=tag(p);tag(ls)+=tag(p);
dat(rs)+=tag(p);tag(rs)+=tag(p);
tag(p)=0;
}
}
inline void alter(int p,int l,int r,int v)
{
if(l<=l(p)&&r>=r(p))
{
dat(p)+=v;
tag(p)+=v;
return;
}
push(p);
int mid=l(p)+r(p)>>1;
if(l<=mid) alter(ls,l,r,v);
if(r>mid) alter(rs,l,r,v);
dat(p)=min(dat(ls),dat(rs));
}
inline int find(int x){return lower_bound(p+1,p+num+1,x)-p;}
inline void add(int x,int v)
{
for(;x<=n;x+=(x&-x)) c[x]+=v;
}
inline int ask(int x)
{
int as=0;
for(;x;x-=(x&-x)) as+=c[x];
return as;
}
inline void solve()
{
memset(c,0,sizeof(c));
for(int i=1;i<=n;++i) p[i]=a[i];
sort(p+1,p+n+1);ans=0;
num=unique(p+1,p+n+1)-p-1;
for(int i=1;i<=n;++i)
{
int x=find(a[i]);
ans+=ask(num)-ask(x);
add(x,1);
}
}
int main()
{
// freopen("1.in","r",stdin);
int T;scanf("%d",&T);
while(T--)
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;++i) scanf("%d",&a[i]),d[i]=i;
for(int i=1;i<=m;++i) scanf("%d",&b[i]);
sort(b+1,b+m+1);
sort(d+1,d+n+1,cmp);
solve();
build(1,1,n+1);
int j=0,k=0;//j表示上一次枚举到的相等的地方。
for(int i=1;i<=m;++i)
{
while(k!=j)
{
++k;
alter(1,1,d[k],1);
}
while(j+1<=n&&a[d[j+1]]<b[i])
{
++j;
alter(1,1,d[j],1);
alter(1,d[j]+1,n+1,-1);
}
k=j;
while(j+1<=n&&a[d[j+1]]==b[i])
{
++j;
alter(1,d[j]+1,n+1,-1);
}
int ts=i;
while(ts+1<=m&&b[ts+1]==b[ts]) ++ts;
ans+=(ll)dat(1)*(ts-i+1);
i=ts;
}
printf("%lld\n",ans);
}
return 0;
}