【BZOJ4548】小奇的糖果
Description
有 N 个彩色糖果在平面上。小奇想在平面上取一条水平的线段,并拾起它上方或下方的所有糖果。求出最多能够拾起多少糖果,使得获得的糖果并不包含所有的颜色。
Solution
看到平面坐标系上维护一些点的数量,我们可以用排序降一维。
具体来说,我们将x坐标离散化,然后做一条平行于x轴的扫描线,动态维护当前扫描到的点的数量。
对于这题,我们先考虑在扫描线上方的点(下方的同理),假如这条线段在最下方,那么我们可以枚举一种颜色不选,然后统计颜色相同横坐标最靠近两个点之间的点的数量,然后对答案取 max 。维护这个,可以用双向链表,即维护一个点左(右)边横坐标最大(小)的与它颜色相同的点。
接下来按y坐标排序,扫描线向上移,现在要删掉扫描线上的点,然后我们考虑删掉之后使点数增加的情况。那么对于一个被删掉的点,我们找到它在双向链表前后两个点,那么此时两个点之间的点且在扫描线上方的就可以对答案取 max 。(这里注意先删去点,再统计)
于是把y坐标取反再排序,重新做一次即可。
维护点的数量可以用树状数组或线段树实现。
Code
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#define fo(i,j,k) for(int i=j;i<=k;i++)
#define fd(i,j,k) for(int i=j;i>=k;i--)
#define N 100010
using namespace std;
struct node{
int x,y,c;
int nx,wz;
}b[N*2];
bool cmpc(node x,node y){
if(x.c==y.c) return x.x<y.x;
else return x.c<y.c;
}
bool cmpy(node x,node y){
return x.y<y.y;
}
bool cmpx(node x,node y){
return x.nx<y.nx;
}
int mx=0,n,m;
int nx[N*2],ls[N*2];
int sn[N*2],sl[N*2];
bool bz[N];
void link(int x,int y){
nx[x]=y,ls[y]=x;
}
void cut(int x){
nx[ls[x]]=nx[x],ls[nx[x]]=ls[x];
}
int ans=0;
int top[N],tr[N];
int lowbit(int x){
return x & -x;
}
void add(int x,int t){
while(x<=mx) tr[x]+=t,x+=lowbit(x);
}
int sum(int x){
int tmp=0;
while(x) tmp+=tr[x],x-=lowbit(x);
return tmp;
}
int tot;
int wz[N];
void calc()
{
sort(b+1,b+n+1,cmpy);
fo(i,1,n) wz[b[i].wz]=i;
int p=0;
fo(i,1,n)
{
while(b[p+1].y==b[i].y && p<n) add(b[++p].x,-1);
fo(j,i,p)
{
int o=b[j].wz,tmp=0;
int l=ls[o],r=nx[o];
if(l<=n) l=wz[l];
if(r<=n) r=wz[r];
if(b[l].x<b[r].x-1) tmp=sum(b[r].x-1)-sum(b[l].x);
if(tmp>ans) ans=tmp;
cut(o);
}
i=p;
}
}
int main()
{
int T;
scanf("%d",&T);
while(T--)
{
scanf("%d %d",&n,&m);
fo(i,1,n) scanf("%d %d %d",&b[i].nx,&b[i].y,&b[i].c),bz[b[i].c]=true;
int cnt=0;
fo(i,1,m) if(bz[i]) cnt++;
if(cnt<m)
{
printf("%d\n",n);
continue;
}
sort(b+1,b+n+1,cmpx);
mx=0;
b[1].x=++mx;
fo(i,2,n) if(b[i].nx!=b[i-1].nx) b[i].x=++mx;
else b[i].x=mx;
tot=n;
sort(b+1,b+n+1,cmpc);//Datum
ans=0;
fo(i,1,n)
{
add(b[i].x,1);
b[i].wz=i;
if(b[i].c!=b[i-1].c) top[b[i].c]=++tot,link(tot,i);
else link(i-1,i);
if(b[i].c!=b[i+1].c) b[++tot].x=mx+1,link(i,tot);
}
fo(i,1,m)
{
int j=top[i];
while(nx[j])
{
if(b[j].x<b[nx[j]].x-1)
{
int tmp=sum(b[nx[j]].x-1)-sum(b[j].x);
if(tmp>ans) ans=tmp;
}
j=nx[j];
}
}
memcpy(sn,nx,sizeof(nx));
memcpy(sl,ls,sizeof(ls));
calc();
fo(i,1,n) b[i].y=-b[i].y,add(b[i].x,1);
memcpy(nx,sn,sizeof(sn));
memcpy(ls,sl,sizeof(sl));
calc();
printf("%d\n",ans);
memset(nx,0,sizeof(nx));
memset(ls,0,sizeof(ls));
memset(bz,0,sizeof(bz));
}
}