51Nod 1810 连续区间

https://www.51nod.com/onlineJudge/questionCode.html#!problemId=1810

题目给出一个1~n的排列,问有多少连续区间。连续区间的定义为区间内元素排序后之间间隔为1。

对于一个区间[l,r],令mid=(l+r)/2,我们如果能在O(n)内求解出左端点在[l,mid],右端点在[mid+1,r]的连续区间数量,就可以将问题一分为二,递归求解[l,mid] [mid+1,r]。

现在来求解上面所说的这个子问题,首先默认i<j,有一个结论max[i~j]-min[i~j]==j-i时[i,j]是一个连续区间。所以我们维护两组从mid(mid+1)出发,向左(右)延伸的后(前)缀max和min数组。max[i~j]=max(max[i],max[j]),min同理。我们只要找到所有i,j组合使得结论式成立即可。

但是很明显朴素枚举是n^2的,我们能不能优化到n呢?

观察max[i~j],min[i~j],他们的来源有4种组合,因为是两两对应的,我们其中两个组合来讨论,另外两个可以同理推导。

第一种是max和min都来自mid左边。有max[i]-min[i]=j-i,推导得 j=max[i]-min[i]+i,只要枚举左半边的i,并判断j是否合法即可。

第二种是max来自左边min来自右边。有max[i]-min[j]=j-i,推导得max[i]+i=min[j]+j,枚举每个max[i]+i,计算有多少合法的j符合即可。说得轻松,这个j的数量怎么计算呢?

我们发现max想要来自左边,需要满足max[i]>max[j],同理min[i]>min[j]。同时我们发现从中心向外发散,max递增,min递减,也就是从中心向左枚举i的过程中max限制越来越宽松,min越来越严格。我们lp,rp来维护右半区间中满足当前i的[mid+1,r]的j的子窗口,可以预见的是随着i--,这个窗口会向右尺取。我们每加入一个j,就让一个计数数组中的cnt[min[j]+j]++,每剔除一个j则相应减减。对于每个i,完成尺取后,ans+=cnt[max[i]+i]。

对应搞定剩下两个情况后这题就理论AC了,剩下还有一些细节,比如l==r的处理,cnt数组的复原处理,输入挂优化什么的,搞搞就AC了。

#include <iostream>
#include <cmath>
#include <algorithm>
#include <map>
#include <cstring>
#define LL long long
using namespace std;
const LL N = 1000005;
int num[N],n;
LL ans;
int mx[N], mi[N];
int read() {
    char ch;
    for (ch = getchar(); ch<'0' || ch>'9'; ch = getchar());
    int x = ch - '0';
    for (ch = getchar(); ch >= '0'&&ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
    return x;
}
void make_pre(LL l, LL r, LL mid)
{
    mi[mid] = mx[mid] = num[mid];
    mi[mid + 1] = mx[mid + 1] = num[mid + 1];
    for (int i = mid - 1; i >= l; i--)
    {
        mx[i] = max(num[i], mx[i + 1]);
        mi[i] = min(num[i], mi[i + 1]);
    }
    for (int i = mid + 2; i <= r; i++)
    {
        mx[i] = max(num[i], mx[i - 1]);
        mi[i] = min(num[i], mi[i - 1]);
    }
}
struct tong
{
    LL cnt[N * 3];
    void clear()
    {
        memset(cnt, 0, sizeof(cnt));
    }
    void setZero(LL num)
    {
        cnt[num + N] = 0;
    }
    void add(LL num,int v)
    {
        cnt[num + N]+=v;
    }
    LL query(LL num)
    {
        return cnt[num+N];
    }
}cnt;

void solve(LL l, LL r)
{
    
    int temp = ans;
    LL mid = (l + r) /2;
    make_pre(l, r, mid);
    //same i
    for (int i = l; i <= mid; i++)
    {
        int nj = mx[i] - mi[i]+i;
        if (nj>mid&&nj<=r&&mx[i]>mx[nj] && mi[i]<mi[nj]) ans++;
    }
    //same j
    for (int i = mid+1; i <= r; i++)
    {
        int nj = mx[i] - mi[i]-i;
        nj = -nj;
        if (nj<=mid&&nj>=l&&mx[i]>mx[nj] && mi[i]<mi[nj]) ans++;
    }
    //dif mx[i],mi[j]
    LL pl = mid+1,pr=mid+1;
    for (int i = mid; i>=l; i--)
    {
        while (pr <= r&&mx[pr] < mx[i])cnt.add(mi[pr] + pr,1),pr++;
        while (pl < pr&&mi[pl] > mi[i])cnt.add(mi[pl] + pl, -1), pl++;
        //if (cnt.query(mx[i] + i) < 0) cout << l << ' ' << r << ' ' << pl << endl;
        ans += cnt.query(mx[i] + i);
    }
    while (pl < pr)cnt.setZero(mi[pl]+pl),pl++;
    //dif mi[i],mx[j]
    pl = mid , pr = mid;
    for (int i=mid+1; i<=r; i++)
    {
        while (pr >=l&&mx[pr] < mx[i])cnt.add(mi[pr] - pr, 1), pr--;
        while (pl > pr&&mi[pl] > mi[i])cnt.add(mi[pl] - pl, -1), pl--;
        //if(cnt.query(mx[i] - i)<0)cout << l << ' ' << r << ' ' << pl << endl;
        ans += cnt.query(mx[i] - i);
    }
    while (pl > pr)cnt.setZero(mi[pl] - pl), pl--;
    //cout << ans - temp << ' ' << l << ' ' << r << endl;
    if (l == r)return;
    solve(l, mid);
    solve(mid + 1, r);
}
int main()
{
    //cin.sync_with_stdio(false);
    n = read();
        for (int i = 0; i < n; i++)num[i]=read();
        ans = 0;
        cnt.clear();
        solve(0, n - 1);
        printf("%lld\n", ans+n);
    
    return 0;
}

 

posted @ 2017-08-23 23:30  Luke_Ye  阅读(324)  评论(0编辑  收藏  举报