【PR #12】划分序列 / Yet Another Mex Problem 题解
题目大意
给定一个长度为 \(n\) 的序列 \(a\),定义一段区间的价值为该区间的 \(\operatorname{mex}\) 乘上区间元素总和。
你需要将序列划分成若干个长度 \(\leq k\) 的区间。一个划分方案的价值为划分出来的每个区间价值之和,求所有划分方案的价值最大值。
\(1 \leq k \leq n \leq 2\times 10^5,0 \leq a_i \leq n\)。
题目分析
第一眼:这不直接维护 \(\operatorname{mex}\) 连续段之后对每段维护凸包,然后用 deque+启发式合并就行了吗?
实际上不行,因为从左往右 dp 的时候连续段会分裂。
记 \(s\) 为 \(a\) 的前缀和,先把 dp 的转移写出来:\(f_i=\max\limits_{i-k+1\leq j \leq i}{(f_{j-1}+\operatorname{mex}(a_j,\ldots,a_i)\times(s_i-s_{j-1}))}\)。
然后从左往右维护以 \(i\) 为结尾的 \(\operatorname{mex}\) 连续段,对 \(\operatorname{mex}\) 相同的一段区间 \([l,r]\) 放一起考虑。
将转移拆成 \((f_{j-1}-\operatorname{mex}\times s_{j-1})+(\operatorname{mex}\times s_i)\),然后分成两个部分。对于前面的式子,只需要将 \([l,r]\) 内的所有点 \((s_{j-1},f_{j-1})\) 建个凸包,然后查斜率为 \(\operatorname{mex}\) 的最大值即可,设其为 \(v_{\operatorname{mex}}\)。将所有连续段的贡献求出来后,后面的式子相当于求 \(\max\limits_{\operatorname{mex}}{(v_{\operatorname{mex}}+\operatorname{mex}\times s_i)}\),也可以看成将所有点 \((-\operatorname{mex},v_{\operatorname{mex}})\) 建个凸包,然后查斜率为 \(s_i\) 的最大值。
众所周知,整个过程只会有 \(\mathcal O(n)\) 个 \(\operatorname{mex}\) 连续段,将其预处理出来,每个连续段 \((l,r,x,m)\) 表示所有左端点在 \([l,r]\) 内,右端点 \(\geq x\) 的区间 \(\operatorname{mex}\) 值 \(\geq m\)(由于是最大值,所以算小了肯定不优)。
然后考虑 \(k\) 的限制,这相当于能贡献到 \(i\) 的为若干整连续段+某一段的后缀。将后缀单独考虑,然后就变成了若干个整段可以给 \(i\) 贡献,用刚才预处理出来的 \((l,r,x,m)\) 表示即为 \([l,r]\) 可以给 \([x,l+k]\) 贡献,\(\operatorname{mex}\) 是 \(m\)。
第一部分是简单的,直接用线段树维护凸包。对于线段树节点 \([L,R]\),等到 \(f_L\sim f_R\) 求出来后在该节点上建凸包,由于凸包上的点本身是按照横坐标排序的,所以建凸包的时间复杂度是 \(\mathcal O(n\log n)\)。查询的时候由于随着时间的推移,以每个位置为左端点的 \(\operatorname{mex}\) 肯定单调不降,所以直接对线段树上每个凸包维护一个指针即可,查询的总时间复杂度 \(\mathcal O(n\log n)\)。
然后是第二部分,这相当于对每个 \((l,r,x,m)\) 求出 \([l,r]\) 第一部分的答案 \(v_m\) 后,将 \((-m,v_m)\) 这个点贡献到 \([x,l+k]\) 这个区间,然后枚举到 \(i\) 的时候查询 \(i\) 这个位置上凸包在斜率为 \(s_i\) 时的最大值。比较简单的做法是把 \((-m,v_m)\) 看成直线,然后用线段树+李超线段树维护,不过时间复杂度是 \(\mathcal O(n\log^2 n)\) 的,不太行。
还是考虑用线段树维护凸包,由于查询时 \(s_i\) 单调不降,所以如果能建出凸包,那么查询的时候就只需要维护指针,还是单 \(\log\);关键是将 \((-m,v_m)\) 塞到线段树节点的时候并不能保证 \(m\) 单增,这导致插入凸包时多一个 \(\log\)。
不过由于插入的点的横坐标 \(m\) 是提前知道的,所以可以离线。先将所有 \((l,r,x,m)\) 按照 \(m\) 排序,然后按顺序依次插入线段树上 \([x,l+k]\) 这段区间,这样就可以把每个线段树节点将要插入的点提前按照横坐标排序,同时处理出每个 \((l,r,x,m)\) 在每个插入的线段树节点上的排名。
由于对于所有的 \((l,r,x,m)\),都有 \(r\leq x\),所以遍历到每个线段树节点 \([L,R]\) 时,所有可能贡献给当前节点的连续段都已经计算完毕,这时再用已经排好序的点建凸包即可。时间复杂度 \(\mathcal O(n\log n)\)。
吐槽一下:为什么单 \(\log\) 跑得比双 \(\log\) 慢这么多啊/ll/ll/ll
代码
#include<bits/stdc++.h>
using namespace std;
using namespace my_std;
#define LC x<<1
#define RC x<<1|1
set<pair<int,pair<int,int> > > s;
set<pair<pair<int,int>,int> > ss;
vector<int> vec[200020],qry[800080];
vector<pair<int,int> > nd[600060];
int n,k,a[200020],pre[200020],mp[200020],cnt=0;
ll f[200020],sum[200020],pos[800080];
struct node{
int l,r,x,mex;
}b[600060],c[200020];
struct point{
ll x,y;
};
vector<point> tree[800080];
il bl operator<(const node &x,const node &y){
return x.mex<y.mex;
}
il point operator-(const point &x,const point &y){
return (point){x.x-y.x,x.y-y.y};
}
il ll operator*(const point &x,const point &y){
return x.x*y.y-x.y*y.x;
}
struct seg{
vector<point> vec[800080];
ll pos[800080];
il void pushup(ll x){
vec[x]=vec[LC];
fr(i,0,(ll)vec[RC].size()-1){
while((ll)vec[x].size()>1&&(vec[RC][i]-vec[x][(ll)vec[x].size()-1])*(vec[x][(ll)vec[x].size()-1]-vec[x][(ll)vec[x].size()-2])<=0) vec[x].pop_back();
vec[x].push_back(vec[RC][i]);
}
pos[x]=(ll)vec[x].size()-1;
}
void mdf(ll x,ll l,ll r,ll v){
if(l==r){
vec[x].push_back((point){sum[l],f[l]});
return;
}
ll mid=(l+r)>>1;
if(v<=mid) mdf(LC,l,mid,v);
else mdf(RC,mid+1,r,v);
if(v==r) pushup(x);
}
ll query(ll x,ll l,ll r,ll ql,ll qr,ll v){
if(ql<=l&&r<=qr){
while(pos[x]&&(vec[x][pos[x]]-vec[x][pos[x]-1])*(point){1,v}>=0) pos[x]--;
if(pos[x]<(ll)vec[x].size()) return vec[x][pos[x]].y-v*vec[x][pos[x]].x;
else return -inf;
}
ll mid=(l+r)>>1,res=-inf;
if(ql<=mid) res=max(res,query(LC,l,mid,ql,qr,v));
if(mid<qr) res=max(res,query(RC,mid+1,r,ql,qr,v));
return res;
}
}T;
il void pushup(ll x){
ll top=0;
fr(i,0,(ll)tree[x].size()-1){
while(top>1){
if(tree[x][i].x==tree[x][top-1].x){
if(tree[x][i].y>tree[x][top-1].y) top--;
else break;
}
else if((tree[x][i]-tree[x][top-1])*(tree[x][top-1]-tree[x][top-2])<=0) top--;
else break;
}
tree[x][top++]=tree[x][i];
}
tree[x].resize(top);
pos[x]=max(0ll,top-1);
}
void ins(ll x,ll l,ll r,ll ql,ll qr,ll v){
if(ql>qr) return;
if(ql<=l&&r<=qr){
qry[x].push_back(v);
tree[x].push_back((point){0,0});
nd[v].push_back(MP(x,(ll)qry[x].size()-1));
return;
}
ll mid=(l+r)>>1;
if(ql<=mid) ins(LC,l,mid,ql,qr,v);
if(mid<qr) ins(RC,mid+1,r,ql,qr,v);
}
ll query(ll x,ll l,ll r,ll v,ll w){
while(pos[x]&&(tree[x][pos[x]]-tree[x][pos[x]-1])*(point){1,w}>=0) pos[x]--;
ll res=-inf;
if(pos[x]<(ll)tree[x].size()) res=tree[x][pos[x]].y-w*tree[x][pos[x]].x;
if(l==r) return res;
ll mid=(l+r)>>1;
if(v<=mid) res=max(res,query(LC,l,mid,v,w));
else res=max(res,query(RC,mid+1,r,v,w));
return res;
}
void solve(ll x,ll l,ll r){
pushup(x);
if(l==r){
f[l]=max(f[l],query(1,0,n,l,sum[l]));
if(c[l].l){
ll tmp=T.query(1,0,n,c[l].l-1,c[l].r-1,c[l].mex);
f[l]=max(f[l],tmp+c[l].mex*sum[l]);
}
T.mdf(1,0,n,l);
fr(i,0,(ll)vec[l+1].size()-1){
ll id=vec[l+1][i],tmp=T.query(1,0,n,b[id].l-1,b[id].r-1,b[id].mex);
fr(j,0,(ll)nd[id].size()-1) tree[nd[id][j].fir][nd[id][j].sec]=(point){-b[id].mex,tmp};
}
}
else{
ll mid=(l+r)>>1;
solve(LC,l,mid);
solve(RC,mid+1,r);
}
}
int main(){
n=read();
k=read();
fr(i,1,n){
a[i]=read();
pre[i]=mp[a[i]];
mp[a[i]]=i;
}
ll lst=n;
fr(i,0,n){
if(mp[i]<lst){
s.insert(MP(i,MP(mp[i]+1,lst)));
ss.insert(MP(MP(mp[i]+1,lst),i));
lst=mp[i];
}
}
pfr(i,n,1){
set<pair<pair<int,int>,int> >::iterator jt=ss.upper_bound(MP(MP(i-k+1,n+1),n+1));
if(jt!=ss.begin()){
jt--;
if((*jt).fir.sec>=(i-k+1)) c[i]=(node){i-k+1,(*jt).fir.sec,i,(*jt).sec};
}
set<pair<int,pair<int,int> > >::iterator it=s.begin();
pair<int,pair<int,int> > now=*it;
b[++cnt]=(node){now.sec.fir,i,i,now.fir};
s.erase(now);
ss.erase(MP(now.sec,now.fir));
if(now.sec.fir<i){
s.insert(MP(now.fir,MP(now.sec.fir,i-1)));
ss.insert(MP(MP(now.sec.fir,i-1),now.fir));
}
mp[a[i]]=pre[i];
it=s.lower_bound(MP(a[i]+1,MP(0,0)));
ll tmp=0;
if(it!=s.end()) tmp=(*it).sec.sec;
while(it!=s.end()){
now=*it;
if(now.sec.sec<=mp[a[i]]) break;
s.erase(now);
ss.erase(MP(now.sec,now.fir));
b[++cnt]=(node){now.sec.fir,now.sec.sec,i,now.fir};
if(now.sec.fir<=mp[a[i]]){
s.insert(MP(now.fir,MP(now.sec.fir,mp[a[i]])));
ss.insert(MP(MP(now.sec.fir,mp[a[i]]),now.fir));
break;
}
it=s.lower_bound(MP(a[i]+1,MP(0,0)));
}
if(mp[a[i]]<tmp){
s.insert(MP(a[i],MP(mp[a[i]]+1,tmp)));
ss.insert(MP(MP(mp[a[i]]+1,tmp),a[i]));
}
}
fr(i,1,n) sum[i]=sum[i-1]+a[i];
sort(b+1,b+cnt+1);
fr(i,1,cnt) vec[b[i].x].push_back(i);
pfr(i,cnt,1) ins(1,0,n,b[i].x,min(n,b[i].l+k-1),i);
solve(1,0,n);
write(f[n]);
}