【YbtOJ#463】序列划分
题目
题目链接:https://www.ybtoj.com.cn/contest/115/problem/1
\(n\leq 10^5,m\leq 10^{12},a_i,b_i\leq 2\times 10^9\)。
思路
我们记 \(\text{nxt}_i\) 表示满足 \(a_j\geq b_i\) 的最大的 \(j\)。那么我们可以把序列分成若干段,其中第 \(i\) 段是 \([l_i,r_i]\) 且满足 \(\max^{r_i}_{j=l_i}(\text{nxt}_j)\leq r_i\)。
接下来 \(a_i,b_i\) 表示第 \(i\) 段 \(a\) 的最大值,\(b\) 的前缀和。
那么显然最终的划分一定是把若干相邻的块合并。考虑二分每一段 \(b\) 的和的最大值,然后 dp 判定是否有解。
设 \(f_i\) 表示前 \(i\) 段合并后,每一段合法时 \(a\) 的最大值之和是多少。考虑加入第 \(i\) 段时:
\[f_i=\min^{i-1}_{b_i-b_j\leq \text{mid}}(f_j+\max^{i}_{k=j+1}(a_k))
\]
发现每次加入 \(i\) 后 \(\max\) 受影响的是一段后缀,然后查询也是一段区间的最小值,直接上线段树优化 dp 即可。
时间复杂度 \(O(n\log n\log a)\)。
代码
#include <bits/stdc++.h>
#define int long long
using namespace std;
typedef long long ll;
const int N=100010,Inf=1e18;
int n,n1,mr,nxt[N],pre[N],lft[N];
ll m,maxa,sumb,a[N],b[N],f[N],maxn[N];
struct node
{
ll a,id;
}c[N];
bool cmp(node x,node y)
{
return (x.a==y.a)?(x.id>y.id):(x.a>y.a);
}
void init()
{
for (int i=n;i>=1;i--)
maxn[i]=max(maxn[i+1],a[i]);
for (int i=1;i<=n;i++)
{
int l=i+1,r=n,mid;
while (l<=r)
{
mid=(l+r)>>1;
if (maxn[mid]>=b[i]) l=mid+1;
else r=mid-1;
}
nxt[i]=l-1; mr=max(mr,nxt[i]);
maxa=max(maxa,a[i]); sumb+=b[i];
if (mr==i)
{
a[++n1]=maxa; b[n1]=sumb+b[n1-1];
maxa=sumb=0;
}
}
n=n1;
}
void prework()
{
set<int> s;
s.insert(0); s.insert(Inf);
for (int i=1;i<=n;i++)
c[i]=(node){a[i],i};
sort(c+1,c+1+n,cmp);
for (int i=1;i<=n;i++)
{
pre[c[i].id]=*(--s.lower_bound(c[i].id));
s.insert(c[i].id);
}
}
struct SegTree
{
ll val[N*4],minn[N*4],lazy[N*4];
void clr()
{
memset(minn,0x3f3f3f3f,sizeof(minn));
memset(val,0x3f3f3f3f,sizeof(val));
memset(lazy,0,sizeof(lazy));
}
void pushdown(int x,int l,int r)
{
if (lazy[x])
{
val[x*2]=val[x*2+1]=lazy[x];
lazy[x*2]=lazy[x*2+1]=lazy[x];
minn[x*2]=f[l-1]+val[x*2];
minn[x*2+1]=f[r-1]+val[x*2+1];
lazy[x]=0;
}
}
void pushup(int x)
{
minn[x]=min(minn[x*2],minn[x*2+1]);
}
void update(int x,int l,int r,int ql,int qr,ll v,bool typ)
{
if (ql>qr) return;
if (ql<=l && qr>=r)
{
if (!typ)
{
val[x]=lazy[x]=v;
minn[x]=f[l-1]+v;
}
else minn[x]=v+val[x];
return;
}
int mid=(l+r)>>1;
pushdown(x,l,mid+1);
if (ql<=mid) update(x*2,l,mid,ql,qr,v,typ);
if (qr>mid) update(x*2+1,mid+1,r,ql,qr,v,typ);
pushup(x);
}
ll query(int x,int l,int r,int ql,int qr)
{
if (ql>qr) return Inf;
if (ql<=l && qr>=r) return minn[x];
int mid=(l+r)>>1; ll res=Inf;
pushdown(x,l,mid+1);
if (ql<=mid) res=min(res,query(x*2,l,mid,ql,qr));
if (qr>mid) res=min(res,query(x*2+1,mid+1,r,ql,qr));
return res;
}
}seg;
bool check(ll mid)
{
for (int i=1,j=0;i<=n;i++)
{
while (b[i]-b[j]>mid) j++;
lft[i]=j;
}
seg.clr();
seg.update(1,1,n+1,1,1,0,1);
for (int i=1;i<=n;i++)
{
seg.update(1,1,n+1,pre[i]+1,i,a[i],0);
f[i]=seg.query(1,1,n+1,lft[i]+1,i);
seg.update(1,1,n+1,i+1,i+1,f[i],1);
}
return f[n]<=m;
}
signed main()
{
freopen("sequence.in","r",stdin);
freopen("sequence.out","w",stdout);
scanf("%lld",&n); scanf("%lld",&m);
for (int i=1;i<=n;i++)
scanf("%lld%lld",&a[i],&b[i]);
init(); prework();
ll l=1,r=2e14,mid;
while (l<=r)
{
mid=(l+r)>>1;
if (check(mid)) r=mid-1;
else l=mid+1;
}
printf("%lld\n",r+1);
return 0;
}