【POJ1716】Integer Intervals——差分约束||贪心
题目大意:给出n个区间,现在要你找出一个点集,使得这n个区间都至少有2个元素在这个点集里面,问这个点集最少有几个点。
解法一:差分约束系统
分析:其实这道题应该说是POJ1201的简化版,不过要注意的一点是,如果你用的是SPFA,那么你的差分约束系统应该为:
s[b+1]-s[a]>=2;
s[b+1]-s[b]>=0;
s[b]-s[b+1]>=1.
为什么下标要全部加上1呢?因为这里的a和b有可能为0,如果按照原来s[a-1]的写法会出现是s[-1]这类数组越界的问题。
代码:
#include<cstdio> #include<algorithm> #include<cstring> const int maxn=1e4+5,inf=0x3f3f3f3f; using namespace std; struct point{ int next,w,to; }e[maxn*3]; int s,t,tot=0,first[maxn],minn=inf,maxx=0,q[maxn]; int dis[maxn]; bool vis[maxn]; int read() { int ans=0,f=1;char c=getchar(); while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();} while(c>='0'&&c<='9'){ans=ans*10+c-48;c=getchar();} return ans*f; } void add(int u,int v,int wi) { tot++;e[tot].next=first[u];first[u]=tot;e[tot].to=v;e[tot].w=wi; } void spfa() { int head=0,tail=1; q[head]=minn;dis[minn]=0;vis[minn]=1; while(head!=tail){ int x=q[head];head++;if(head>=1001)head=0; for(int i=first[x];i;i=e[i].next){ int to=e[i].to; if(dis[to]<dis[x]+e[i].w){ dis[to]=dis[x]+e[i].w; if(!vis[to]){ vis[to]=1; q[tail]=to; tail++; if(tail>=1001)tail=0; } } } vis[x]=0; } } int main() { int n=read(),a,b; for(int i=1;i<=n;i++){ a=read();b=read(); add(a,b+1,2); minn=min(minn,a); maxx=max(maxx,b+1); } for(int i=minn;i<=maxx;i++){ add(i,i+1,0); add(i+1,i,-1); } memset(dis,-127,sizeof(dis)); memset(vis,0,sizeof(vis)); spfa(); printf("%d",dis[maxx]); return 0; }
解法二:贪心
分析:其实这道题贪心更好写而且可以跑得更快。先把每段区间按照右端点从小到大排序,贪心的基本策略就是尽量取这个区间最后两个整数。这里需要两个变量x和y,x初始化为第一个区间的倒数第二个元素,y为倒数第一个元素。每次枚举到下一个区间时,若该区间已经包含有x和y这两个元素,则直接跳转到下一个区间;若只包含了其中一个元素(因为x<y所以该元素必然是y),则x=y,y=e[i].b,同时sum++;若两个元素都不包含,则将x和y重新更新为当前区间的最后两个值,同时sum+2。
代码:
#include<cstdio> #include<algorithm> #include<cstring> const int maxn=1e4+5; using namespace std; struct point{ int a,b; }e[maxn]; int read() { int ans=0,f=1;char c=getchar(); while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();} while(c>='0'&&c<='9'){ans=ans*10+c-48;c=getchar();} return ans*f; } bool cmp(point c,point d){return c.b<d.b;} int main() { int n=read(),x,y,sum=2; for(int i=1;i<=n;i++){ e[i].a=read();e[i].b=read(); } sort(e+1,e+1+n,cmp); x=e[1].b-1;y=e[1].b; for(int i=2;i<=n;i++){ if(e[i].a<=x)continue; if(e[i].a<=y){ sum++;x=y;y=e[i].b;continue; } x=e[i].b-1;y=e[i].b;sum+=2; } printf("%d",sum); return 0; }