CF1601C Optimal Insertion

Problem

Problem statement

You are given two arrays of integers $ a_1, a_2, \ldots, a_n $ and $ b_1, b_2, \ldots, b_m $ .

You need to insert all elements of $ b $ into $ a $ in an arbitrary way. As a result you will get an array $ c_1, c_2, \ldots, c_{n+m} $ of size $ n + m $ .

Note that you are not allowed to change the order of elements in $ a $ , while you can insert elements of $ b $ at arbitrary positions. They can be inserted at the beginning, between any elements of $ a $ , or at the end. Moreover, elements of $ b $ can appear in the resulting array in any order.

What is the minimum possible number of inversions in the resulting array $ c $ ? Recall that an inversion is a pair of indices $ (i, j) $ such that $ i < j $ and $ c_i > c_j $ .

Input format

Each test contains multiple test cases. The first line contains the number of test cases $ t $ ( $ 1 \leq t \leq 10^4 $ ). Description of the test cases follows.

The first line of each test case contains two integers $ n $ and $ m $ ( $ 1 \leq n, m \leq 10^6 $ ).

The second line of each test case contains $ n $ integers $ a_1, a_2, \ldots, a_n $ ( $ 1 \leq a_i \leq 10^9 $ ).

The third line of each test case contains $ m $ integers $ b_1, b_2, \ldots, b_m $ ( $ 1 \leq b_i \leq 10^9 $ ).

It is guaranteed that the sum of $ n $ over all test cases in one input doesn't exceed $ 10^6 $ . The sum of $ m $ over all test cases doesn't exceed $ 10^6 $ as well.

Output format

For each test case, print one integer — the minimum possible number of inversions in the resulting array $ c $ .

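Problem summary

Overview

You are given two sequences \(a,b\) of lengths \(n,m\ (1\leq n,m\leq 10^6)\). Insert all elements of \(b\) into \(a\) at arbitrary positions and in arbitrary order. Find an insertion that minimizes the number of inversions in the resulting sequence, and output that minimum.

About insertion: the following example shows what "arbitrary order, arbitrary positions" means.

For example, with \(a=\{1,2,3,4\},b=\{4,5,6\}\), both \(c=\{4,\underline1,5,\underline2,\underline3,\underline4,6\}\) and \(c=\{\underline1,\underline2,6,5,\underline3,4,\underline4\}\) are valid results, among others. You may not change the order of \(a\).

Input format

This problem has multiple test cases (remember to reset your state between test cases, or you will score zero).

The first line contains a positive integer \(t\ (1\leq t\leq 10^4)\), the number of test cases.

For each test case, the first line contains two integers \(n,m\ (1\leq n,m\leq 10^6)\), the lengths of \(a\) and \(b\).

The second line contains \(n\) integers, the sequence \(a\).

The third line contains \(m\) integers, the sequence \(b\).

It is guaranteed that \(1\leq a_i,b_i\leq 10^9,\ 1\leq \sum n,\sum m\leq 10^6\).

Output format

For each test case, print one line with one integer: the minimum number of inversions.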

Sample input/output

Input #1

3
3 4
1 2 3
4 3 2 1
3 3
3 2 1
1 2 3
5 4
1 3 5 3 1
4 3 6 1

Output #1

0
4
6

Notes

Below is given the solution to get the optimal answer for each of the example test cases (elements of $ a $ are underscored).

  • In the first test case, $ c = [\underline{1}, 1, \underline{2}, 2, \underline{3}, 3, 4] $ .
  • In the second test case, $ c = [1, 2, \underline{3}, \underline{2}, \underline{1}, 3] $ .
  • In the third test case, $ c = [\underline{1}, 1, 3, \underline{3}, \underline{5}, \underline{3}, \underline{1}, 4, 6] $ .

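Approach

Lemma: there is always an optimal arrangement in which the elements of \(b\) appear in \(c\) in non-decreasing order, i.e. \(c\) can be obtained by sorting \(b\) and inserting its elements into \(a\) from left to right.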

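Proof:

Suppose \(a\) is split into consecutive segments \(A\), \(B\), \(C\), take \(x < y\), and suppose that inserting \(x\) between \(B\) and \(C\) is optimal. Consider inserting \(y\) either between \(A\) and \(B\) or between \(B\) and \(C\).

Clearly, in both options \(y\) forms the same number of inversions with the elements of \(A\) and \(C\).

Let \(B\) contain \(g_1\) elements greater than \(x\) and \(s_1\) elements smaller than \(x\). Then \(s_1 \ge g_1\) (because inserting \(x\) between \(B\) and \(C\) is optimal).

Let \(B\) contain \(g_2\) elements greater than \(y\) and \(s_2\) elements smaller than \(y\). Then \(g_2 \le g_1\) and \(s_2\ge s_1\) (because \(y>x\)).

Hence \(s_2 \ge s_1 \ge g_1 \ge g_2\). Placing \(y\) before \(B\) costs \(s_2\) inversions with \(B\), while placing it after \(B\) costs only \(g_2\). In addition, placing \(y\) before \(x\) creates the extra inversion \((y,x)\). So placing \(y\) between \(A\) and \(B\) is never better than placing it between \(B\) and \(C\).

In summary, there is always a position at or after the one where \(x\) was inserted at which \(y\) produces the fewest inversions.

Q.E.D.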
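The lemma can be sanity-checked by brute force on tiny inputs. The sketch below is not part of the original solution: the helper names (`countInversionsNaive`, `interleave`, `bestWithSortedB`, `bestOverAllOrders`) are made up for illustration, and the search is exponential, so it is only meant for arrays of a handful of elements.

```cpp
#include <algorithm>
#include <climits>
#include <vector>

// Count inversions directly from the definition, O(k^2); fine for tiny inputs.
long long countInversionsNaive(const std::vector<int>& c) {
    long long inv = 0;
    for (size_t i = 0; i < c.size(); ++i)
        for (size_t j = i + 1; j < c.size(); ++j)
            inv += (c[i] > c[j]);
    return inv;
}

// Enumerate every interleaving of a with `order` (b in one fixed order).
// ia and ib are the next unused indices of a and order.
void interleave(const std::vector<int>& a, const std::vector<int>& order,
                size_t ia, size_t ib, std::vector<int>& c, long long& best) {
    if (ia == a.size() && ib == order.size()) {
        best = std::min(best, countInversionsNaive(c));
        return;
    }
    if (ia < a.size()) {
        c.push_back(a[ia]);
        interleave(a, order, ia + 1, ib, c, best);
        c.pop_back();
    }
    if (ib < order.size()) {
        c.push_back(order[ib]);
        interleave(a, order, ia, ib + 1, c, best);
        c.pop_back();
    }
}

// Minimum over all interleavings with b kept in ascending order.
long long bestWithSortedB(const std::vector<int>& a, std::vector<int> b) {
    std::sort(b.begin(), b.end());
    std::vector<int> c;
    long long best = LLONG_MAX;
    interleave(a, b, 0, 0, c, best);
    return best;
}

// Minimum over all interleavings and all orders of b.
long long bestOverAllOrders(const std::vector<int>& a, std::vector<int> b) {
    std::sort(b.begin(), b.end());
    long long best = LLONG_MAX;
    do {
        std::vector<int> c;
        interleave(a, b, 0, 0, c, best);
    } while (std::next_permutation(b.begin(), b.end()));
    return best;
}
```

On the three sample tests both functions return the same values (0, 4 and 6), which is what the lemma predicts.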

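So for each \(b_i\) (after sorting) we compute \(pos_i\), meaning that \(b_i\) is inserted right after \(a_{pos_i}\) (when \(pos_i\) equals \(0\), it is inserted at the very front of \(a\)).

It follows that \(pos\) can be chosen as a monotonically non-decreasing array.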

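Therefore, let \(solve(l_1,r_1,l_2,r_2)\) handle the insertion of \(b_{l_1\ldots r_1}\) when their positions are known to lie in \([l_2,r_2]\), i.e. within \(a_{l_2\ldots r_2}\).

Take \(mid = \lfloor \frac 12 (l_1+r_1) \rfloor\), find \(pos_{mid}\) in \(O(r_2-l_2)\) time, then recurse into \(solve(l_1,mid-1,l_2,pos_{mid})\) and \(solve(mid+1,r_1,pos_{mid},r_2)\).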
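The recursion can be sketched in isolation as follows. This is an illustrative version under assumptions, not the author's listing: `solveRange` and `insertionPositions` are hypothetical names, the arrays are 1-indexed `std::vector`s with index 0 unused, and `b` must already be sorted in ascending order.

```cpp
#include <vector>

// Fills pos[l1..r1] with values in [l2, r2]; pos[i] = p means that b[i]
// is inserted right after a[p] (p == 0: in front of a[1]).
void solveRange(const std::vector<int>& a, const std::vector<int>& b,
                std::vector<int>& pos, int l1, int r1, int l2, int r2) {
    if (l1 > r1) return;
    int mid = (l1 + r1) / 2;
    // Only a[l2+1..r2] changes the relative cost of candidates in [l2, r2].
    int pre = 0, suf = 0;
    for (int i = l2 + 1; i <= r2; ++i) suf += (a[i] < b[mid]);
    int best = suf;
    pos[mid] = l2;
    for (int i = l2 + 1; i <= r2; ++i) {
        pre += (a[i] > b[mid]);
        suf -= (a[i] < b[mid]);
        if (pre + suf < best) {
            best = pre + suf;
            pos[mid] = i;
        }
    }
    // Smaller values go no further right, larger values no further left.
    solveRange(a, b, pos, l1, mid - 1, l2, pos[mid]);
    solveRange(a, b, pos, mid + 1, r1, pos[mid], r2);
}

// Convenience wrapper: returns pos[0..m] (pos[0] is unused and left at 0).
std::vector<int> insertionPositions(const std::vector<int>& a,
                                    const std::vector<int>& b) {
    int n = (int)a.size() - 1, m = (int)b.size() - 1;
    std::vector<int> pos(m + 1, 0);
    solveRange(a, b, pos, 1, m, 0, n);
    return pos;
}
```

For the second sample (\(a = 3, 2, 1\) and sorted \(b = 1, 2, 3\)) this yields \(pos = 0, 0, 3\), which corresponds to the arrangement \(c = [1, 2, 3, 2, 1, 3]\) shown in the notes.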

Once \(pos\) is known, count the inversions of the merged sequence with a Fenwick tree (binary indexed tree).
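As a reference for this counting step, here is a small self-contained sketch that counts the inversions of an already merged sequence \(c\). The `Fenwick` struct and the `countInversions` function are illustrative names and differ from the `TreeArray` used in the full listing. Values are compressed to ranks first, so values up to \(10^9\) are fine.

```cpp
#include <algorithm>
#include <vector>

// Fenwick tree over ranks 1..n storing how many times each rank was seen.
struct Fenwick {
    std::vector<long long> t;
    explicit Fenwick(int n) : t(n + 1, 0) {}
    void add(int i, long long v) {
        for (; i < (int)t.size(); i += i & -i) t[i] += v;
    }
    long long sum(int i) const {       // total count of ranks 1..i
        long long s = 0;
        for (; i > 0; i -= i & -i) s += t[i];
        return s;
    }
};

// Number of pairs i < j with c[i] > c[j].
long long countInversions(const std::vector<int>& c) {
    std::vector<int> vals(c);
    std::sort(vals.begin(), vals.end());
    vals.erase(std::unique(vals.begin(), vals.end()), vals.end());
    Fenwick fw((int)vals.size());
    long long inv = 0;
    for (size_t i = 0; i < c.size(); ++i) {
        int rank = int(std::lower_bound(vals.begin(), vals.end(), c[i])
                       - vals.begin()) + 1;
        // i elements were seen so far; fw.sum(rank) of them are <= c[i].
        inv += (long long)i - fw.sum(rank);
        fw.add(rank, 1);
    }
    return inv;
}
```

Applied to the three arrangements listed in the notes, it returns 0, 4 and 6.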

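There are at most \(O(m)\) recursive calls and the recursion depth is \(O(\log m)\). The ranges scanned on one level of the recursion total \(O(n+m)\), so \(solve\) runs in \(O((n+m)\log m)\). Counting the inversions with the Fenwick tree takes \(O((n+m)\log(n+m))\).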

Code

Best viewed in VS Code.

#include <iostream>
#include <cstdio>
#include <algorithm>

//#define int long long
typedef long long ll;

int read() {
	int re = 0;
	char  c = getchar();
	bool negt = false;
	while(c < '0' ||c > '9')
		negt |= (c == '-') , c = getchar();
	while(c >= '0' && c <= '9')
		re = (re << 1) + (re << 3) + c - '0' , c = getchar();
	return negt ? -re : re;
}

const int N = 1000010;

struct TreeArray {
#define lowbit(_) ((_) & -(_))
	ll a[N * 2];// note the array size: indices go up to n + m
	int n;
	void set(int n_) {
		n = n_;
		for(int i = 0 ; i <= n ; i++)a[i] = 0;
	}
	void change(int pos , int dat) {
		if(pos == 0)return;
		for( ; pos <= n ; pos += lowbit(pos))a[pos] += dat;
	}
	ll GetSum(int r) {
		ll sum = 0;
		for( ; r > 0 ; r -= lowbit(r))sum += a[r];
		return sum;
	}
	ll GetSum(int l , int r) {
		return GetSum(r) - GetSum(l - 1);
	}
};

int n , m;
int a[N] , b[N];
int pos[N];

void discretize() {// coordinate compression of a and b into ranks
	static int tmp[N * 2];
	int siz = 0;
	for(int i = 1 ; i <= n ; i++)tmp[++siz] = a[i];
	for(int i = 1 ; i <= m ; i++)tmp[++siz] = b[i];
	std::sort(tmp + 1 , tmp + siz + 1);
	siz = std::unique(tmp + 1 , tmp + siz + 1) - tmp - 1;
	for(int i = 1 ; i <= n ; i++)
		a[i] = std::upper_bound(tmp + 1 , tmp + siz + 1 , a[i]) - tmp - 1;
	for(int i = 1 ; i <= m ; i++)
		b[i] = std::upper_bound(tmp + 1 , tmp + siz + 1 , b[i]) - tmp - 1;
}


#define divideOptimize 1
#if divideOptimize
int pre , suf;
TreeArray statis;
void divide(int l1 , int r1 , int l2 , int r2) {
	if(l1 > r1)return;
	int mid = (l1 + r1) / 2;
	pos[mid] = l2;

	// Candidate i means "insert b[mid] right after a[i]" (i == 0: before a[1]).
	// Only a[l2+1..r2] affects the relative cost of candidates in [l2, r2],
	// so the scan never produces l2 - 1 and never reads outside a[1..n].
	suf = pre = 0;
	for(int i = l2 + 1 ; i <= r2 ; i++)suf += (a[i] < b[mid]);
	ll inversionNum = suf;// cost of the default candidate pos[mid] = l2
	for(int i = l2 + 1 ; i <= r2 ; i++) {
		pre += (a[i] > b[mid]);
		suf -= (a[i] < b[mid]);
		if(inversionNum > suf + pre)
			inversionNum = suf + pre , pos[mid] = i;
	}

	divide(l1 , mid - 1 , l2 , pos[mid]);
	divide(mid + 1 , r1 , pos[mid] , r2);
}
#else
TreeArray pre , suf;// accidentally wrote a log^2 divide-and-conquer here; it gets TLE
void divide(int l1 , int r1 , int l2 , int r2) {
	if(l1 > r1)return;
	int mid = (l1 + r1) / 2;
	pos[mid] = l2;

	pre.change(a[l2] , 1) , suf.change(a[l2] , -1);
	int inversionNum = pre.GetSum(b[mid] + 1 , n + m) + suf.GetSum(b[mid] - 1);
	for(int i = l2 + 1 ; i <= r2 ; i++) {
		pre.change(a[i] , 1) , suf.change(a[i] , -1);
		int newNum = pre.GetSum(b[mid] + 1 , n + m) + suf.GetSum(b[mid] - 1);
		if(newNum < inversionNum) {
			inversionNum = newNum;
			pos[mid] = i;
		}
	}

	for(int i = l2 ; i <= r2 ; i++)pre.change(a[i] , -1) , suf.change(a[i] , 1);

	divide(l1 , mid - 1 , l2 , pos[mid]);
	pre.change(a[pos[mid]] , -1) , suf.change(a[pos[mid]] , 1);
	divide(mid + 1 , r1 , pos[mid] , r2);
}
#endif

void solve() {
	n = read() , m = read();
	for(int i = 1 ; i <= n ; i++)a[i] = read();
	for(int i = 1 ; i <= m ; i++)b[i] = read();

	discretize();
	std::sort(b + 1 , b + m + 1);

#if !divideOptimize
	suf.set(n + m) , pre.set(n + m);
	for(int i = 1 ; i <= n ; i++)
		suf.change(a[i] , 1);
#endif
	divide(1 , m , 0 , n);

	statis.set(n + m);// count the inversions of the merged sequence
	ll ans = 0;
	int j = 1;
	for(int i = 1 ; i <= n ; i++) {
		while(j <= m && pos[j] < i)ans += statis.GetSum(b[j] + 1 , n + m) , statis.change(b[j] , 1) , ++j;
		ans += statis.GetSum(a[i] + 1 , n + m) , statis.change(a[i] , 1);
	}
	while(j <= m) ans += statis.GetSum(b[j] + 1 , n + m) , statis.change(b[j] , 1) , ++j;
	printf("%lld\n" , ans);
}
signed main() {
	int T = read();
	while(T--)
		solve();
	return 0;
}
/*
1
7 1
13 2 6 4 12 7 11
10

*/

posted @ 2021-11-02 21:28  追梦人1024