逆序对[AHOI2008]
【题目描述】
小可可和小卡卡想到Y岛上旅游,但是他们不知道Y岛有多远。好在,他们找到一本古老的书,上面是这样说的: 下面是\(N\)个正整数,每个都在\(1\sim K\)之间。如果有两个数\(A\)和\(B\),\(A\)在\(B\)左边且\(A\)大于\(B\),我们就称这两个数为一个“逆序对”。你数一数下面的数字里有多少个逆序对,你就知道Y岛离这里的距离是多少千米了。 比如说,\(4,2,1,3,3\)里面包含了\(5\)个逆序对:\((4, 2), (4, 1), (4, 3), (4, 3), (2, 1)\)。 可惜的是,由于年代久远,这些数字里有一部分已经模糊不清了,为了方便记录,小可可用\(-1\)表示它们。比如说,\(4,2,-1,-1,3\) 可能原来是\(4,2,1,3,3\),也可能是\(4,2,4,4,3\),也可能是别的样子。小可可希望知道,根据他们看清楚的这部分数字,能不能推断出这些数字里最少能有多少个逆序对。
【输入格式】
第一行两个正整数 \(N\) 和 \(K\) 。第二行 \(N\) 个整数,每个都是 \(-1\) 或是一个在 \(1\sim K\) 之间的数。
\(N\le 10000, K\le 100\)
【输出格式】
一个正整数,即这些数字里最少的逆序对个数。
题解
首先证明一下\(-1\)的地方填的数一定是单调不降的
假设\(A\le B\) 那么现在的逆序对数是\(a+b+y\) 如果交换\(a,b\) 那么逆序对数会变成\(a+b+x\)或者\(a+b+x+1\)(取决于\(A=B\))
由于\(A\le B\) 所以\(y\le x\) 所以\(a+b+y\le a+b+x<a+b+x+1\) 在\(A,B\)右边小于\(A,B\)的也同理 不证了 所以肯定是不交换,即\(A\le B\)比较优
我们完全可以先把那些不是\(-1\)的元素之间的逆序对先算出来 然后我们就只用计算每个\(-1\)位置产生的逆序对数了
预处理出\(sumfr[i][j]\)表示\(a[1\sim i]\)除\(-1\)外有多少个大于\(j\)的数,\(sumbk[i][j]\)表示\(a[i\sim n]\)除\(-1\)外有多少个小于\(j\)的数
子状态:\(dp[i][j]\)表示前\(i\)个\(-1\) 最后一个填了\(j\)的最小逆序对数 \(pos[i]\)表示第\(i\)个\(-1\)在原数组中的下标
由于我们刚才推出了那个单调不降的性质 转移方程就很简单了:\(dp[i][j]=\min\limits_{k=1}^{j}dp[i-1][k]+sumfr[pos[i]-1][j]+sumbk[pos[i]+1][j]\)
时间复杂度\(O(nk)\)
代码
#include <bits/stdc++.h>
#define lowbit(x) x&(-x)
#define N 20005
#define K 205
using namespace std;
typedef long long ll;
inline int read() {
int x = 0, f = 1; char ch = getchar();
for (; ch > '9' || ch < '0'; ch = getchar()) if (ch == '-') f = -1;
for (; ch <= '9' && ch >= '0'; ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ '0');
return x * f;
}
int n, m, a[N];
ll sumfr[N][K], sumbk[N][K];
ll dp[N][K];
ll tr[N], ans;
int q[N], tot;
inline void update(int ind, int v) {
for (; ind <= m; ind += lowbit(ind)) {
tr[ind] += v;
}
}
inline ll getsum(int ind) {
ll ret = 0;
for (; ind; ind -= lowbit(ind)) {
ret += tr[ind];
}
return ret;
}
inline void getans() {
ans = 0;
memset(tr, 0, sizeof(tr));
for (int i = 1; i <= n; i++) {
if (a[i] == -1) continue;
ans += getsum(m) - getsum(a[i]);
update(a[i], 1);
}
}
inline void init() {
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= m; j++) {
sumfr[i][j] = sumfr[i-1][j];
if (a[i] != -1) sumfr[i][j] += (a[i] > j);
}
}
for (int i = n; i; i--) {
for (int j = 1; j <= m; j++) {
sumbk[i][j] = sumbk[i+1][j];
if (a[i] != -1) sumbk[i][j] += (a[i] < j);
}
}
for (int i = 1; i <= n; i++) {
if (a[i] == -1) {
q[++tot] = i;
}
}
}
int main() {
n = read(); m = read();
for (int i = 1; i <= n; i++) a[i] = read();
getans();
init();
memset(dp, 0x3f, sizeof(dp));
for (int j = 1; j <= m; j++) dp[0][j] = 0;
for (int i = 1; i <= tot; i++) {
for (int j = 1; j <= m; j++) {
dp[i][j] = min(dp[i][j-1], dp[i-1][j]);
}
for (int j = 1; j <= m; j++) {
dp[i][j] += sumfr[q[i]-1][j] + sumbk[q[i]+1][j];
}
}
ll mn = 0x3f3f3f3f3f3f3f3f;
for (int j = 1; j <= m; j++) {
mn = min(mn, dp[tot][j]);
}
printf("%lld\n", ans + mn);
return 0;
}