【BZOJ5017】[Snoi2017]炸弹 线段树优化建图+Tarjan+拓扑排序

【BZOJ5017】[Snoi2017]炸弹

Description

在一条直线上有 N 个炸弹,每个炸弹的坐标是 Xi,爆炸半径是 Ri,当一个炸弹爆炸时,如果另一个炸弹所在位置 Xj 满足: 
Xi−Ri≤Xj≤Xi+Ri,那么,该炸弹也会被引爆。 
现在,请你帮忙计算一下,先把第 i 个炸弹引爆,将引爆多少个炸弹呢? 

Input

第一行,一个数字 N,表示炸弹个数。 
第 2∼N+1行,每行 2 个数字,表示 Xi,Ri,保证 Xi 严格递增。 
N≤500000
−10^18≤Xi≤10^18
0≤Ri≤2×10^18

Output

一个数字,表示Sigma(i*炸弹i能引爆的炸弹个数),1<=i<=N mod10^9+7。 

Sample Input

4
1 1
5 1
6 5
15 15

Sample Output

32

题解:比较naive的做法就是从i向i能波及到的所有炸弹连边,然后看一下从一个点能走到多少个点就行了,但是边数是O(n^2)的,不过由于i能波及到的炸弹在一个区间中,所以考虑线段树优化建图。

建完图怎么做呢?跑Tarjan+拓扑排序即可。但是拓扑排序是不能传递siz的,怎么办?其实没必要传递siz,因为一个炸弹最终能引爆的炸弹也在一段区间中,所以我们只需要传递这个区间最左边和最右边的点即可。

代码挺长的~其实网上还有更简单的做法~

#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
#include <vector>
#include <algorithm>
#define lson x<<1
#define rson x<<1|1
using namespace std;
typedef long long ll;
const int P=1000000007;
const int maxn=2000010;
int n,cnt,tot,sum,top,nn;
int to[40000000],next[40000000],head[maxn],dep[maxn],low[maxn],bel[maxn],sm[maxn],sn[maxn],p[maxn],pos[maxn],Q[maxn];
int sta[maxn],ins[maxn],d[maxn];
ll x[maxn],r[maxn],ans;
vector<int> ch[maxn];
queue<int> q;
inline void add(int a,int b)
{
	//printf("*%d %d\n",a,b);
	to[cnt]=b,next[cnt]=head[a],head[a]=cnt++;
}
void build(int l,int r,int x)
{
	if(l==r)
	{
		p[x]=l,pos[l]=x,nn=max(nn,x);
		return ;
	}
	add(x,lson),add(x,rson);
	int mid=(l+r)>>1;
	build(l,mid,lson),build(mid+1,r,rson);
}
void query(int l,int r,int x,int a,int b,int y)
{
	if(a<=l&&r<=b)
	{
		add(y,x);
		return ;
	}
	int mid=(l+r)>>1;
	if(a<=mid)	query(l,mid,lson,a,b,y);
	if(b>mid)	query(mid+1,r,rson,a,b,y);
}
void tarjan(int x)
{
	dep[x]=low[x]=++tot,sta[++top]=x,ins[x]=1;
	for(int i=head[x];i!=-1;i=next[i])
	{
		if(!dep[to[i]])	tarjan(to[i]),low[x]=min(low[x],low[to[i]]);
		else	if(ins[to[i]])	low[x]=min(low[x],dep[to[i]]);
	}
	if(dep[x]==low[x])
	{
		int t;
		sum++,sm[sum]=0,sn[sum]=1<<30;
		do
		{
			t=sta[top--],ins[t]=0,bel[t]=sum;
			if(p[t])	sm[sum]=max(sm[sum],p[t]),sn[sum]=min(sn[sum],p[t]);
		}while(t!=x);
	}
}
inline ll rd()
{
	ll ret=0,f=1;	char gc=getchar();
	while(gc<'0'||gc>'9')	{if(gc=='-')f=-f;	gc=getchar();}
	while(gc>='0'&&gc<='9')	ret=ret*10+gc-'0',gc=getchar();
	return ret*f;
}
int main()
{
	n=rd();
	int i,j,a,b;
	memset(head,-1,sizeof(head));
	build(1,n,1);
	for(i=1;i<=n;i++)	x[i]=rd(),r[i]=rd();
	for(i=1;i<=n;i++)
		a=lower_bound(x+1,x+n+1,x[i]-r[i])-x,b=upper_bound(x+1,x+n+1,x[i]+r[i])-x-1,query(1,n,1,a,b,pos[i]);
	tarjan(1);
	for(i=1;i<=nn;i++)	for(j=head[i];j!=-1;j=next[j])	if(bel[i]!=bel[to[j]])
		ch[bel[i]].push_back(bel[to[j]]),d[bel[to[j]]]++;
	for(i=1;i<=sum;i++)	if(!d[i])	q.push(i);
	while(!q.empty())
	{
		a=q.front(),q.pop(),Q[++Q[0]]=a;
		for(i=0;i<(int)ch[a].size();i++)
		{
			b=ch[a][i],d[b]--;
			if(!d[b])	q.push(b);
		}
	}
	for(i=sum;i>=1;i--)	for(a=Q[i],j=0;j<(int)ch[a].size();j++)	b=ch[a][j],sm[a]=max(sm[a],sm[b]),sn[a]=min(sn[a],sn[b]);
	for(i=1;i<=n;i++)	ans=(ans+(ll)(sm[bel[pos[i]]]-sn[bel[pos[i]]]+1)*i)%P;
	printf("%lld",ans);
	return 0;
}
posted @ 2017-09-13 15:56  CQzhangyu  阅读(293)  评论(0编辑  收藏  举报