[APIO2015]八邻旁之桥——非旋转treap

题目链接:

[APIO2015]八邻旁之桥

对于$k=1$的情况:

对于起点和终点在同侧的直接计入答案;对于不在同侧的,可以发现答案就是所有点坐标与桥坐标的差之和+起点与终点不在同一侧的人数。

将所有点排序,要使答案最优,桥坐标就是这些点坐标的中位数,用平衡树维护一下求中位数即可。

对于$k=2$的情况:

同样先将起点和终点在同侧的直接计入答案。显然两座桥比一座更优,我们将每个人的起点与终点坐标看成一条线段。那么对于每条线段,它的中点离哪座桥近它就走哪座桥更优。我们将每条线段按中点坐标排序,将所有线段分为两部分,显然左边部分选靠左的桥、右边部分选择靠右的桥。那么只需要枚举中间的分界线,然后两部分分别按$k=1$考虑就行。维护两棵平衡树,每次将第二棵中的两个点(被划为左半部分的线段的起点和终点)删除,插入到第一棵中。

#include<set>
#include<map>
#include<queue>
#include<stack>
#include<cmath>
#include<vector>
#include<cstdio>
#include<bitset>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
int ls[200010];
int rs[200010];
int size[200010];
int v[200010];
ll sum[200010];
int r[200010];
int n,k;
int cnt;
int tot;
int root;
ll ans;
ll res;
int L,R;
int x,y,z;
int a[200010];
char s[2],t[2];
struct lty
{
    int x,y;
}p[100010];
bool operator < (lty a,lty b){return a.x+a.y<b.x+b.y;}
int newnode(int x)
{
    int rt=++cnt;
    r[rt]=rand();
    size[rt]=1;
    v[rt]=x;
    sum[rt]=x;
    return rt;
}
void pushup(int rt)
{
    size[rt]=size[ls[rt]]+size[rs[rt]]+1;
    sum[rt]=sum[ls[rt]]+sum[rs[rt]]+v[rt];
}
int merge(int x,int y)
{
    if(!x||!y)
    {
        return x+y;
    }
    if(r[x]<r[y])
    {
        rs[x]=merge(rs[x],y);
        pushup(x);
        return x;
    }
    else
    {
        ls[y]=merge(x,ls[y]);
        pushup(y);
        return y;
    }
}
void split(int rt,int &x,int &y,int k)
{
    if(!rt)
    {
        x=y=0;
        return ;
    }
    if(size[ls[rt]]>=k)
    {
        y=rt;
        split(ls[rt],x,ls[y],k);
    }
    else
    {
        x=rt;
        split(rs[rt],rs[x],y,k-size[ls[rt]]-1);
    }
    pushup(rt);
}
void split2(int rt,int &x,int &y,int k)
{
    if(!rt)
    {
        x=y=0;
        return ;
    }
    if(v[rt]>=k)
    {
        y=rt;
        split2(ls[rt],x,ls[y],k);
    }
    else
    {
        x=rt;
        split2(rs[rt],rs[x],y,k);
    }
    pushup(rt);
}
int build(int l,int r)
{
    if(l==r)
    {
        return newnode(a[l]);
    }
    int mid=(l+r)>>1;
    return merge(build(l,mid),build(mid+1,r));
}
int del(int &rt,int k)
{
    split2(rt,x,y,k);
    split(y,y,z,1);
    rt=merge(x,z);
    return y;
}
void ins(int &rt,int k,int id)
{
    split2(rt,x,y,k);
    rt=merge(merge(x,id),y);
}
void solve1()
{
    for(int i=1;i<=n;i++)
    {
        scanf("%s%d%s%d",s,&x,t,&y);
        if(s[0]==t[0])
        {
            ans+=abs(y-x);
        }
        else
        {
            ans++;
            a[++tot]=x;
            a[++tot]=y;
        }
    }
    if(tot==0)
    {
        printf("%lld",ans);
        return ;
    }
    sort(a+1,a+1+tot);
    root=build(1,tot);
    int mid=(size[root]+1)/2;
    split(root,x,y,mid-1);
    split(y,y,z,1);
    ans+=1ll*size[x]*v[y]-sum[x];
    ans+=sum[z]-1ll*size[z]*v[y];
    root=merge(merge(x,y),z);
    printf("%lld",ans);
}
void solve2()
{
    for(int i=1;i<=n;i++)
    {
        scanf("%s%d%s%d",s,&x,t,&y);
        if(s[0]==t[0])
        {
            ans+=abs(x-y);
        }
        else
        {
            ans++;
            tot++;
            p[tot].x=x,p[tot].y=y;
            a[tot*2-1]=x,a[tot*2]=y;
        }
    }
    if(tot==0)
    {
        printf("%lld",ans);
        return ;
    }
    sort(p+1,p+1+tot);
    sort(a+1,a+1+tot*2);
    root=build(1,tot*2);
    L=0,R=root;
    ll mn=1ll<<60;
    for(int i=1;i<=tot;i++)
    {
        res=0;
        int l=del(R,p[i].x);
        int r=del(R,p[i].y);
        ins(L,v[l],l);
        ins(L,v[r],r);
        int mid=(size[L]+1)/2;
        split(L,x,y,mid-1);
        split(y,y,z,1);
        res+=1ll*size[x]*v[y]-sum[x];
        res+=sum[z]-1ll*size[z]*v[y];
        L=merge(merge(x,y),z);
        mid=(size[R]+1)/2;
        split(R,x,y,mid-1);
        split(y,y,z,1);
        res+=1ll*size[x]*v[y]-sum[x];
        res+=sum[z]-1ll*size[z]*v[y];
        R=merge(merge(x,y),z);
        mn=min(mn,res);
    }
    ans+=mn;
    printf("%lld",ans);
}
int main()
{
    scanf("%d%d",&k,&n);
    if(k==1)solve1();
    else solve2();
}
posted @ 2019-05-14 22:19  The_Virtuoso  阅读(255)  评论(0编辑  收藏  举报