[APIO2015]八邻旁之桥——非旋转treap
题目链接:
对于$k=1$的情况:
对于起点和终点在同侧的直接计入答案;对于不在同侧的,可以发现答案就是所有点坐标与桥坐标的差之和+起点与终点不在同一侧的人数。
将所有点排序,要使答案最优,桥坐标就是这些点坐标的中位数,用平衡树维护一下求中位数即可。
对于$k=2$的情况:
同样先将起点和终点在同侧的直接计入答案。显然两座桥比一座更优,我们将每个人的起点与终点坐标看成一条线段。那么对于每条线段,它的中点离哪座桥近它就走哪座桥更优。我们将每条线段按中点坐标排序,将所有线段分为两部分,显然左边部分选靠左的桥、右边部分选择靠右的桥。那么只需要枚举中间的分界线,然后两部分分别按$k=1$考虑就行。维护两棵平衡树,每次将第二棵中的两个点(被划为左半部分的线段的起点和终点)删除,插入到第一棵中。
#include<set> #include<map> #include<queue> #include<stack> #include<cmath> #include<vector> #include<cstdio> #include<bitset> #include<cstring> #include<iostream> #include<algorithm> #define ll long long using namespace std; int ls[200010]; int rs[200010]; int size[200010]; int v[200010]; ll sum[200010]; int r[200010]; int n,k; int cnt; int tot; int root; ll ans; ll res; int L,R; int x,y,z; int a[200010]; char s[2],t[2]; struct lty { int x,y; }p[100010]; bool operator < (lty a,lty b){return a.x+a.y<b.x+b.y;} int newnode(int x) { int rt=++cnt; r[rt]=rand(); size[rt]=1; v[rt]=x; sum[rt]=x; return rt; } void pushup(int rt) { size[rt]=size[ls[rt]]+size[rs[rt]]+1; sum[rt]=sum[ls[rt]]+sum[rs[rt]]+v[rt]; } int merge(int x,int y) { if(!x||!y) { return x+y; } if(r[x]<r[y]) { rs[x]=merge(rs[x],y); pushup(x); return x; } else { ls[y]=merge(x,ls[y]); pushup(y); return y; } } void split(int rt,int &x,int &y,int k) { if(!rt) { x=y=0; return ; } if(size[ls[rt]]>=k) { y=rt; split(ls[rt],x,ls[y],k); } else { x=rt; split(rs[rt],rs[x],y,k-size[ls[rt]]-1); } pushup(rt); } void split2(int rt,int &x,int &y,int k) { if(!rt) { x=y=0; return ; } if(v[rt]>=k) { y=rt; split2(ls[rt],x,ls[y],k); } else { x=rt; split2(rs[rt],rs[x],y,k); } pushup(rt); } int build(int l,int r) { if(l==r) { return newnode(a[l]); } int mid=(l+r)>>1; return merge(build(l,mid),build(mid+1,r)); } int del(int &rt,int k) { split2(rt,x,y,k); split(y,y,z,1); rt=merge(x,z); return y; } void ins(int &rt,int k,int id) { split2(rt,x,y,k); rt=merge(merge(x,id),y); } void solve1() { for(int i=1;i<=n;i++) { scanf("%s%d%s%d",s,&x,t,&y); if(s[0]==t[0]) { ans+=abs(y-x); } else { ans++; a[++tot]=x; a[++tot]=y; } } if(tot==0) { printf("%lld",ans); return ; } sort(a+1,a+1+tot); root=build(1,tot); int mid=(size[root]+1)/2; split(root,x,y,mid-1); split(y,y,z,1); ans+=1ll*size[x]*v[y]-sum[x]; ans+=sum[z]-1ll*size[z]*v[y]; root=merge(merge(x,y),z); printf("%lld",ans); } void solve2() { for(int i=1;i<=n;i++) { scanf("%s%d%s%d",s,&x,t,&y); if(s[0]==t[0]) { ans+=abs(x-y); } else { ans++; tot++; p[tot].x=x,p[tot].y=y; a[tot*2-1]=x,a[tot*2]=y; } } if(tot==0) { printf("%lld",ans); return ; } sort(p+1,p+1+tot); sort(a+1,a+1+tot*2); root=build(1,tot*2); L=0,R=root; ll mn=1ll<<60; for(int i=1;i<=tot;i++) { res=0; int l=del(R,p[i].x); int r=del(R,p[i].y); ins(L,v[l],l); ins(L,v[r],r); int mid=(size[L]+1)/2; split(L,x,y,mid-1); split(y,y,z,1); res+=1ll*size[x]*v[y]-sum[x]; res+=sum[z]-1ll*size[z]*v[y]; L=merge(merge(x,y),z); mid=(size[R]+1)/2; split(R,x,y,mid-1); split(y,y,z,1); res+=1ll*size[x]*v[y]-sum[x]; res+=sum[z]-1ll*size[z]*v[y]; R=merge(merge(x,y),z); mn=min(mn,res); } ans+=mn; printf("%lld",ans); } int main() { scanf("%d%d",&k,&n); if(k==1)solve1(); else solve2(); }