AtCoderGC038B - Sorting a Segment 数据结构 + RMQ

题意:给定一个长度为N的排列,你只能对其中长度为K的连续子序列进行一次从小到大的排序,问:排序之后能形成多少不同的排列?

数据范围: 1 <= n, k <= 200,000, k <= n.

-----------------------------------分割线--------------------------------

分析此题,我们发现,长度为K的连续子序列在原排列中只有 N-K+1个,也就是说只会有N-K+1个排序情况,得出答案的上界N-K+1.

考虑上界中有多少连续子序列重复计数M,减去M即为答案。

那么剩下的问题就是统计每一个排序之后的连续子序列相同的个数M了。

朴素做法:枚举每一个长度为K的区间,对区间内从小到大排一下序,得出原排列,与其他排列进行比较,统计相同排列的个数cnt,累加每个cnt-1即可。

时间复杂度  O(N^2*Klog(K)).

思考一下优化方法。

设原排列为A1,A2,A3,........,An。

假设一个区间[l,r]排序之后为原排列为P(l,r).

那么如果P(l1,r1) = P(l2,r2)且 r1 - l1 +1 = r2 - l2 + 1 = K。

当且仅当存在以下两种情况,上式成立:

(1) 区间[l1,r1] 和 区间[l2,r2] 原本就从小到大有序。

(2) 区间[l1,r1] 和 区间[l2,r2]相邻,即 l2 = l1+1,r2 = r1+1,且 min[l1,r2] = a[l1],max[l1,r2] = a[r2]. 

结论(1)的正确性显然。

主要讨论结论(2)的正确性:

我们可以知道,区间[l1,r1] 和 区间[l2, r2] 的区间交为[l2,r1],区间并为[l1,r2]。

如果只考虑区间[l2,r1],那么排序结果显然相同。

而P(l1,r1) <=> P(l2,r1)U 由区间[l1,l2-1]中所有元素基于大小关系插入区间[l2,r1]的相应位置。

区间[l2,r2] 同理。

于是我们只需解决区间[l1,l2-1] 和区间 [r1+1,r2]对区间[l2,r1]的 排序影响。

如果[l1,r1] 与 [l2,r2] 不相邻,且非情况(1),则 P(l1,r1) != P(l2,r1),P(l2,r2)!= P(l2,r1),P(l1,r1)!= P(l2,r1)!= P(l2,r2),不存在。

则当l2 = l1+1 时,若min[l1,r2] = a[l1],则P(l1,r1)= P(l2,r1),若max[l1,r2] = a[r2],则P(l2,r2)= P(l2,r1).由传递性可知:P(l1,r1)= P(l2,r1)= P(l2,r2)。结论成立。

证毕。

于是根据这两个结论,我们可以首先求出情况(1)的重复数,扫一遍原排列,求出长度大于等于K的升序区间数量。

对于情况(2),我们先选取区间[1,K],维护最大值和最小值,接着左端点和右端点指针分别往右移,转移到区间[2,K+1],对于区间[1,K]和[2,K+1],判断是否符合min[1,K+1] = a[1] 并且 max[1,K+1] = a[K+1].若符合,则累加到M中,否则继续往右移,直到右端点到N为止。

维护动态区间最大值和最小值可以用STL的堆 或者 set 维护。

插入删除复杂度O(logN),遍历时间复杂度O(N),总时间复杂度O(NlogN).可以通过。

其实还可以用单调队列维护,总时间复杂度降为O(N),大家有兴趣可以尝试一下(我就不试了QAQ).

堆的代码如下:

#include<bits/stdc++.h>

#define ll long long
#define mp make_pair
#define rep(i, a, b) for(int i = (a);i <= (b);i++)
#define per(i, a, b) for(int i = (a);i >= (b);i--)

using namespace std;

typedef pair<int, int> pii;
typedef double db;
const int N = 1e6 + 50;
int n, k, a[N], cnt = 0, flag = 0; 
int ans = 0, maxx, minn, pmax, pmin;
int vis[N], f[N];
priority_queue < int, vector<int>, greater<int> > q;
priority_queue < int, vector<int>, less<int> > p;
inline int read(){
    int x = 0, f = 1;
    char ch = getchar();
    while(ch < '0' || ch > '9'){if(ch == '-') f = -1; ch = getchar();}
    while(ch >='0' && ch <='9'){x = (x<<3)+(x<<1)+(ch^48); ch = getchar();}
    return x*f;
}
void init(){
    n = read(); k = read(); 
    rep(i, 1, n) a[i] = read();
    rep(i, 1, k) p.push(a[i]), q.push(a[i]);
    rep(i, 2, n){
        if(a[i] > a[i-1]){
            int sum = 0;
            while(a[i] > a[i-1] && i <= n) i++, sum++;
            if(sum >= k-1) cnt ++;
        }
    }
    ans = n-k+1;
    int l = 1, r = k;
    while(r <= n){
        r++;
        if(r > n) break;
        while(vis[q.top()]) q.pop(); 
        while(f[p.top()]) p.pop();
        if(a[l] == q.top() && a[r] > p.top()){
            q.pop(); p.push(a[r]);
            q.push(a[r]);
            f[a[l]] = 1;
            ans --;
        }
        else if(a[l] == q.top() && a[r] < p.top()){
            q.pop(); p.push(a[r]);
            q.push(a[r]);
            f[a[l]] = 1;
        }
        else if(a[l] == p.top()){
            p.pop(); p.push(a[r]);
            q.push(a[r]);
            vis[a[l]] = 1;
        }
        else if(a[l] != q.top() && a[l] != p.top()){
            p.push(a[r]); q.push(a[r]);
            f[a[l]] = 1, vis[a[l]] = 1;
        }
        l++;
    }
    if(!cnt) printf("%d\n", ans);
    else printf("%d\n", ans - cnt+1);
}
int main(){
    init();
    return 0;
}
View Code

STL的<set>代码如下:

#include<bits/stdc++.h>

#define ll long long
#define mp make_pair
#define rep(i, a, b) for(int i = (a);i <= (b);i++)
#define per(i, a, b) for(int i = (a);i >= (b);i--)

using namespace std;

typedef pair<int, int> pii;
typedef double db;
const int N = 1e6 + 50;
int n, k, a[N], ans = 0,  cnt;
set <int> s;
set <int>::iterator it;
inline int read(){
    int x = 0, f = 1;
    char ch = getchar();
    while(ch < '0' || ch > '9'){if(ch == '-') f = -1; ch = getchar();}
    while(ch >='0' && ch <='9'){x = (x<<3)+(x<<1)+(ch^48); ch = getchar();}
    return x*f;
}
void init(){
    n = read(); k = read();
    rep(i, 1, n) a[i] = read();
    rep(i, 1, k) s.insert(a[i]);
    rep(i, 2, n){
        if(a[i] > a[i-1]){
            int sum = 0;
            while(a[i] > a[i-1] && i <= n) i++, sum++;
            if(sum >= k-1) cnt ++;
        }
    }
    ans = n-k+1;
    int l = 1, r = k;
    while(l <= r && r <= n){
        r++;
        if(r > n) break;
        s.insert(a[r]);
        if(*(s.rbegin()) == a[r] && (*s.begin()) == a[l]) ans --;
        s.erase(a[l]); 
        l++;
    }
    if(!cnt) printf("%d\n", ans);
    else printf("%d\n", ans - cnt+1);
}
int main(){
    init();
    return 0;
}
View Code

 备注:本题堆的速度比<set>要快,但是代码实现难度更大,推荐用<set>.

posted @ 2019-09-22 13:36  smilke  阅读(316)  评论(0编辑  收藏  举报