XM6138 分糖果(线段树)
XM6138 分糖果
stO jc Orz
解题思路
很神奇的一题。
看到题目首先想到的就是二分答案。那么问题转化为了如何验证一个答案 \(x\)。
我们定义函数 \(f_{i,j}\) 表示考虑前 \(i\) 个小朋友,是否可以分成 \(j\) 段,使得每一段都满足和小于 \(x\)。这个函数很好转移:
\[f_{i,j}\leftarrow f_{k,j-1}\ \ sum_i-sum_k\leq x
\]
其中 \(sum\) 表示前缀和。
目前转移是 \(O(n^3)\) 的,需要优化。那么我们以 \(sum_i\) 为关键字进行排序,构建线段树快速找出 \(k\) 这一维,使得 DP 复杂度来到了 \(O(n^2 \log n)\)。
但是这还远远不够,我们还有一个二分的 \(\log\)。我们通过观察可以发现,满足 \(f_{i,j}=1\) 的 \(j\) 一定实际一个连续的区间。那么我们在线段树上维护满足条件的 \(j\) 的区间即可。最终总时间复杂度为 \(O(n\log^2 n)\)。
现在问题就是,为什么 \(j\) 一定是连续的?在讨论后,我们把问题扔给了 IOI 金牌 jc,然后被秒了(
\(x\leq 0\) 时显然,那 \(x>0\) 时呢?我们可以假设 \(j=x,x+2\) 时成立,那么我们通过容斥发现,\(j=x+1\) 一定成立。至于如何容斥的,读者自证不难。
代码
代码很简单。
//Don't act like a loser.
//This code is written by huayucaiji
//You can only use the code for studying or finding mistakes
//Or,you'll be punished by Sakyamuni!!!
#include<bits/stdc++.h>
using namespace std;
int read() {
char ch=getchar();
int f=1,x=0;
while(ch<'0'||ch>'9') {
if(ch=='-')
f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9') {
x=x*10+ch-'0';
ch=getchar();
}
return f*x;
}
char read_char() {
char ch=getchar();
while(!isalpha(ch)) {
ch=getchar();
}
return ch;
}
const int MAXN=2e4+10;
int n,k,m;
int a[MAXN],sum[MAXN],b[MAXN],idx[MAXN];
struct seg {
int l,r;
}s[MAXN<<2];
seg pushup(seg lft,seg rgt) {
seg ret=s[0];
ret.l=min(lft.l,rgt.l);
ret.r=max(lft.r,rgt.r);
return ret;
}
void build(int l,int r,int p) {
s[p].l=1e9;
s[p].r=-1e9;
if(l==r) {
return ;
}
int mid=(l+r)>>1;
build(l,mid,p<<1);
build(mid+1,r,p<<1|1);
}
void modify(int l,int r,int p,int x,int v1,int v2) {
if(x<l||r<x) {
return ;
}
if(l==r) {
s[p].l=min(s[p].l,v1);
s[p].r=max(s[p].r,v2);
return ;
}
int mid=(l+r)>>1;
modify(l,mid,p<<1,x,v1,v2);
modify(mid+1,r,p<<1|1,x,v1,v2);
s[p]=pushup(s[p<<1],s[p<<1|1]);
}
seg query(int l,int r,int p,int x,int y) {
if(x<=l&&r<=y) {
return s[p];
}
int mid=(l+r)>>1;
if(y<=mid) {
return query(l,mid,p<<1,x,y);
}
if(mid<x) {
return query(mid+1,r,p<<1|1,x,y);
}
return pushup(query(l,mid,p<<1,x,y),query(mid+1,r,p<<1|1,x,y));
}
bool check(int x) {
build(1,m,1);
modify(1,m,1,idx[0],0,0);
for(int i=1;i<n;i++) {
int p=lower_bound(b+1,b+m+1,sum[i]-x)-b;
if(p>m) {
continue;
}
seg res=query(1,m,1,p,m);
res.l++;res.r++;
if(i==n-1) {
if(res.l<=k&&k<=res.r) {
return 1;
}
else {
return 0;
}
}
modify(1,m,1,idx[i],res.l,res.r);
}
return 0;
}
int main() {
cin>>n>>k;
for(int i=1;i<=n;i++) {
a[i]=read();
sum[i]=sum[i-1]+a[i];
b[i]=sum[i];
}
n++;
b[n]=0;
sort(b+1,b+n+1);
m=unique(b+1,b+n+1)-b-1;
for(int i=0;i<n;i++) {
idx[i]=lower_bound(b+1,b+m+1,sum[i])-b;
}
int l=-1e9,r=1e9;
while(l+1<r) {
int mid=(l+r)>>1;
if(check(mid)) {
r=mid;
}
else {
l=mid;
}
}
if(check(l)) {
cout<<l<<endl;
}
else {
cout<<r<<endl;
}
return 0;
}