线段树上二分

别样的线段树

D. Points

原题链接:https://codeforces.com/problemset/problem/19/D

开始思路:

看到题目后有一个想法,先将所有坐标进行离散化,在横坐标方向上建立线段树,每个节点维护一个 set 即对应区间 l ~ ry 轴上的坐标,然后每次增删都可以在 O(log2(n)) 内完成,然后查询时,对区间进行直接二分,然后每次将对应区间的集合合并后取出,每次有效性检验检查是否存在大于当前查询的 y,直接二分时间复杂度为 O(log(n)),每次取出时间复杂度最坏情况下可 O(n),每次 upper _ bound 查询为 O(log(n)),三者相乘,不出意外直接 tle

tle代码:

#include<bits/stdc++.h>
using namespace std;    

typedef long long ll;
typedef pair<int,int> PII;

const int N=2e5+10,mod=998244353;

int n,m;

vector<int> all,query;
vector<PII> points;

struct Node{
    int l,r;
    set<int> ys;
}tr[4*N];

void build(int u,int l,int r){
    if(l==r) tr[u]={l,r};
    else{
        tr[u]={l,r};
        int mid=l+r>>1;
        build(u<<1,l,mid),build(u<<1|1,mid+1,r);
    }
}

void add(int u,int x,int y){
    if(tr[u].l==x&&tr[u].r==x) tr[u].ys.insert(y);
    else{
        int mid=tr[u].l+tr[u].r>>1;
        if(x<=mid) add(u<<1,x,y);
        else add(u<<1|1,x,y);
        tr[u].ys.insert(y);
    }
}

void rme(int u,int x,int y){
    if(tr[u].l==x&&tr[u].r==x) tr[u].ys.erase(y);
    else{
        int mid=tr[u].l+tr[u].r>>1;
        if(x<=mid) rme(u<<1,x,y);
        else rme(u<<1|1,x,y);
        tr[u].ys.erase(y);
    }
}

set<int> get(int u,int l,int r){
    if(tr[u].l>=l&&tr[u].r<=r) return tr[u].ys;
    else{
        int mid=tr[u].l+tr[u].r>>1;
        if(r<=mid) return get(u<<1,l,r);
        else if(l>mid) return get(u<<1|1,l,r);
        else{
            set<int> left,right;
            left=get(u<<1,l,r),right=get(u<<1|1,l,r);
            if(left.size()<=right.size()) swap(left,right);
            for(auto y:right) left.insert(y);
            return left;
        }
    }
}

inline bool check(int L,int R,int mi){
    set<int> temp=get(1,L,R);
    return temp.upper_bound(mi)!=temp.end();
}

PII find(int x,int y){
    int l=x+1,r=all.size();
    
    while(l<r){
        int mid=l+r>>1;
        if(check(l,mid,y))r=mid;
        else l=mid+1;  
    }
    
    if(l>r||!check(l,l,y)) return {-1,-1};

    set<int> temp=get(1,l,l);
    int tx=l,ty=*temp.upper_bound(y);
    
    return {all[tx-1],all[ty-1]};
}

void solve(){
    cin>>n;

    query=vector<int>(n);
    points=vector<PII>(n);

    string s;
    int x,y;

    for(int i=0;i<n;i++){
        cin>>s>>x>>y;
        query[i]=(s=="add")?1:(s=="remove")?2:3;    
        points[i]={x,y};
        all.push_back(x),all.push_back(y);
    }

    sort(all.begin(),all.end());
    all.erase(unique(all.begin(),all.end()),all.end());

    build(1,1,all.size());

    for(int i=0;i<n;i++){
        auto [x,y]=points[i];

        x=lower_bound(all.begin(),all.end(),x)-all.begin()+1;
        y=lower_bound(all.begin(),all.end(),y)-all.begin()+1;

        if(query[i]==1) add(1,x,y);
        else if(query[i]==2) rme(1,x,y);
        else{
            PII t=find(x,y);
            if(t.first!=-1) cout<<t.first<<' '<<t.second<<'\n';
            else cout<<-1<<endl;
        }
    }
}

int main() {
    cin.tie(0)->sync_with_stdio(false);
    cout.tie(0);

    int t=1;
    // cin>>t;

    while(t--)solve();
}

正确做法:线段树上二分

这题确实应该用 set,但应该是直接存储对应 x 坐标上的 y坐标,而线段树应该维护对应区间下 y 方向上坐标的最大值。每次插入,先将对应 y 坐标插入到对应 x 坐标下的 set 中,然后再去看对应坐标下存储的最大值是否改变,然后再去对线段树进行修改;对于删除也是上面的思路。而查询则要用线段树二分,固定查询区间 l~r找到是否存在第一个存储大于 y 的横坐标,每次只去找与查询区间有交集的位置,并且在左边符合条件的情况下,优先查询左边。

代码:

#include<bits/stdc++.h>
using namespace std;    

typedef long long ll;
typedef pair<int,int> PII;

const int N=4e5+10,mod=998244353;

int n,m;

vector<int> all,query;
vector<PII> points;

struct Node{
    int l,r;
    int lmy;
}tr[4*N];

set<int> colx[N];

void build(int u,int l,int r){
    if(l==r) tr[u]={l,r,-1};
    else{
        tr[u]={l,r,-1};
        int mid=l+r>>1;
        build(u<<1,l,mid),build(u<<1|1,mid+1,r);
    }
}

void add(int u,int x,int y){
    if(tr[u].l==x&&tr[u].r==x) tr[u].lmy=y;
    else{
        int mid=tr[u].l+tr[u].r>>1;
        if(x<=mid) add(u<<1,x,y);
        else add(u<<1|1,x,y);
        tr[u].lmy=max(tr[u<<1].lmy,tr[u<<1|1].lmy);
    }
}

int find(int u,int l,int low){  // 线段树上二分查询
    if(l>all.size()) return -1;
    if(tr[u].l==tr[u].r){
        if(tr[u].lmy>low) return tr[u].l;
        return -1;
    }
    int mid=tr[u].l+tr[u].r>>1,res=-1;
    if(l<=mid&&tr[u<<1].lmy>low) res=find(u<<1,l,low);
    if(res!=-1) return res;
    if(tr[u<<1|1].lmy>low) return find(u<<1|1,l,low);
    return -1;
}       

void solve(){
    cin>>n;

    query=vector<int>(n);
    points=vector<PII>(n);

    string s;
    int x,y;

    for(int i=0;i<n;i++){
        cin>>s>>x>>y;
        query[i]=(s=="add")?1:(s=="remove")?2:3;    
        points[i]={x,y};
        all.push_back(x),all.push_back(y);
    }

    sort(all.begin(),all.end());
    all.erase(unique(all.begin(),all.end()),all.end());

    build(1,1,all.size());

    unordered_map<int,int> mp;
    for(int i=0;i<all.size();i++) mp[all[i]]=i+1;

    for(int i=0;i<n;i++){
        auto [x,y]=points[i];

        x=mp[x], y=mp[y];

        if(query[i]==1){
            if(colx[x].empty()||*(--colx[x].end())<y) add(1,x,y);
            colx[x].insert(y);
        }
        if(query[i]==2){
            colx[x].erase(y);
            if(colx[x].empty()) add(1,x,-1);
            else add(1,x,*(--colx[x].end()));
        }
        if(query[i]==3){
            int t=find(1,x+1,y);
            if(t==-1) cout<<-1<<endl;
            else{
                int tx=t, ty=*colx[tx].upper_bound(y);
                cout<<all[tx-1]<<' '<<all[ty-1]<<endl;
            }
        }
    }
}

int main() {
    cin.tie(0)->sync_with_stdio(false);
    cout.tie(0);

    int t=1;
    // cin>>t;

    while(t--)solve();
}
posted @   宋佳奇  阅读(13)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 地球OL攻略 —— 某应届生求职总结
· 周边上新:园子的第一款马克杯温暖上架
· Open-Sora 2.0 重磅开源!
· 提示词工程——AI应用必不可少的技术
· .NET周刊【3月第1期 2025-03-02】
点击右上角即可分享
微信分享提示