CF1601C Optimal Insertion

传送门


题面:给两个序列\(a,b\),将\(b\)中的所有元素按任意顺序插入\(a\)中,求形成的新的序列的最小逆序对数。


这题首先最好观察出这么个结论:如果把\(b_i\)插在\(p_i\)(即\(a_{i-1}\)\(a_i\)之间)得到的逆序对最少,那么当\(b_i < b_j\)时,一定有\(p_i < p_j\).即这个最优插入位置是随着\(b_i\)增大而单调递增的。


知道这个结论后,如果我们能算出来所有的\(p_i\),那么逆序对只会在\(a\)序列本身和\(a,b\)之间产生,\(b\)本身是不会产生逆序对的。

\(p_i\)的方法,除了\(O(n^2)\)暴力外有两种方法:

  1. 线段树。将\(a\)\(b\)放在一块并从小到大排序。刚开始所有\(a_i\)都比任意一个\(b_i\)大,那么有\(p_i=i-1( i \in [1,n+1])\)。考虑遇到一个\(a_i\),那么接下来的\(b_j\)一定比这个\(a_i\)大,那么放在\(a_i\)后面就不会和\(a_i\)产生逆序对,而放在\(a_i\)前面反而会多产生一个逆序对,因此把\(a_i\)在原数组位置之后的\(p_j\)都减1,而之前的\(p_j\)都加1.
    如果遇到\(b_i\),直接查询当前\(p_j\)的最小值即可。这些操作都可以用线段树实现。
    但是\(a,b\)中可能有相同的元素,假设\(a_i=b_j\),那么将\(b_j\)放在\(a_i\)的前面和后面都不会产生逆序对,对比未用\(a_i\)更新数组的情况,会发现我们应该先将\(a_i\)之后的所有\(p_k\)减1,再查询\(b_j\),最后再把\(a_i\)之前的所有\(p_k\)加1.
    这样线段树的时间复杂度就是\(O(m\log n)\).

  2. 分治。这是题解的做法。先把\(b_i\)从小到大排序,因为\(p_i\)\(b_i\)递增,因此对于当前的一段\(b[L,R]\),先求\(b_{\frac{L+R}{2}}\)的最优位置,再递归到左右子区间求解。这样左右子区间扫描的长度之和就只有当前区间的长度。又因为递归了\(\log n\)层,因此时间复杂度也是\(O(n\log n)\)的。


我这里用的第一种写法,模板化,好写(不过分治也挺好写的)。

#include<bits/stdc++.h>
using namespace std;
#define enter puts("") 
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline
#define forE(i, x, y) for(int i = head[x], y; ~i && (y = e[i].to); i = e[i].nxt)
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 1e6 + 5;
In ll read()
{
	ll ans = 0;
	char ch = getchar(), las = ' ';
	while(!isdigit(ch)) las = ch, ch = getchar();
	while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
	if(las == '-') ans = -ans;
	return ans;
}
In void write(ll x)
{
	if(x < 0) x = -x, putchar('-');
	if(x >= 10) write(x / 10);
	putchar(x % 10 + '0');
}

int n, m, N, cnt = 0;
struct Node
{
	int val, pos, flg;
	In bool operator < (const Node& oth)const
	{
		if(val ^ oth.val) return val < oth.val;
		if(flg ^ oth.flg) return flg < oth.flg;
		return pos < oth.pos;
	}
}t[maxn * 3];

int c[maxn];
In int lowbit(int x) {return x & -x;}
In void add(int pos)		//树状数组也要考虑ai相同的情况…… 
{
	for(; pos <= N; pos += lowbit(pos)) c[pos]++;
}
In int query(int pos)
{
	int ret = 0;
	for(; pos; pos -= lowbit(pos)) ret += c[pos];
	return ret;
}

int l[maxn << 2], r[maxn << 2], Min[maxn << 2], lzy[maxn << 2];
In void build(int L, int R, int now)
{
	l[now] = L, r[now] = R;
	lzy[now] = 0;
	if(L == R) {Min[now] = L - 1; return;}
	int mid = (L + R) >>  1;
	build(L, mid, now << 1), build(mid + 1, R, now << 1 | 1);
	Min[now] = min(Min[now << 1], Min[now << 1 | 1]); 
}
In void pushdown(int now)
{
	if(lzy[now])
	{
		Min[now << 1] += lzy[now], lzy[now << 1] += lzy[now];
		Min[now << 1 | 1] += lzy[now], lzy[now << 1 | 1] += lzy[now];
		lzy[now] = 0;
	}
}
In void update(int L, int R, int now, int d)
{
	if(l[now] == L && r[now] == R)
	{
		Min[now] += d, lzy[now] += d;
		return;
	}
	pushdown(now);
	int mid = (l[now] + r[now]) >> 1;
	if(R <= mid) update(L, R, now << 1, d);
	else if(L > mid) update(L, R, now << 1 | 1, d);
	else update(L, mid, now << 1, d), update(mid + 1, R, now << 1 | 1, d);
	Min[now] = min(Min[now << 1], Min[now << 1 | 1]);
}

int main()
{
	int T = read();
	while(T--)
	{
		n = read(), m = read();
		N = n + 1, cnt = 0; 
		fill(c, c + N + 1, 0);
		build(1, N, 1);
		for(int i = 1; i <= n; ++i) 
		{
			int x = read();
			t[++cnt] = (Node){x, i, -1};
			t[++cnt] = (Node){x, i, 1};
		}
		for(int i = 1; i <= m; ++i)
		{
			int x = read();
			t[++cnt] = (Node){x, 0, 0};
		} 
		sort(t + 1, t + cnt + 1);
		ll ans = 0;
		for(int i = 1; i <= cnt; ++i)
		{
			if(t[i].flg == -1)
			{
				ans += query(N) - query(t[i].pos);
				add(t[i].pos);
				update(t[i].pos + 1, N, 1, -1);
			}
			else if(t[i].flg == 0) ans += Min[1];
			else update(1, t[i].pos, 1, 1);
		}
		write(ans), enter;
	}
	return 0;
}
posted @ 2021-11-05 00:55  mrclr  阅读(116)  评论(3编辑  收藏  举报