[BZOJ4299]Codechef FRBSUM
[BZOJ4299]Codechef FRBSUM
题面
数集S的ForbiddenSum定义为无法用S的某个子集(可以为空)的和表示的最小的非负整数。
例如,S={1,1,3,7},则它的子集和中包含0(S’=∅),1(S’={1}),2(S’={1,1}),3(S’={3}),4(S’={1,3}),5(S' = {1, 1, 3}),但是它无法得到6。因此S的ForbiddenSum为6。
给定一个序列A,你的任务是回答该数列的一些子区间所形成的数集的ForbiddenSum是多少。
Input
输入数据的第一行包含一个整数N,表示序列的长度。
接下来一行包含N个数,表示给定的序列A(从1标号)。
接下来一行包含一个整数M,表示询问的组数。
接下来M行,每行一对整数,表示一组询问。
Output
对于每组询问,输出一行表示对应的答案。
Sample Input
5 1 2 4 9 10 5 1 1 1 2 1 3 1 4 1 5
Sample Output
2 4 8 8 8
Hint
对于100%的数据,1≤N,M≤100000,1≤A_i≤10^9 ,1≤A_1+A_2+…+A_N≤10^9。
思路
我们首先考虑一下,如果一个元素新加入一个数列,会发生什么变化。设当前能凑出的最大值为\(max\),现在加入元素\(a\),显然,将每个可以凑出的值加上\(a\)就得到了可以包含\(a\)的子集的和的值域:[a,max+a],如果\(a \leq max+1\),那就可以更新max为\(max+a\)。如果\(a >max+1\)那么很显然,如果加入这个元素对现在的\(max\)值无影响,但是对未来的\(max\)可能会有影响,所以不能直接舍去。
那么对于一次操作,我们可以写一下伪代码:
int max=0;
while(1){
int sum_now=the sum of {a|a<=max+1};//相当于把所有小于等于max+1一次加入
if(sum_now>max){//更新max
max=sum_now;
}
else{//已更新完所有的数,无法再次更新
break;
}
}
那么现在这道题就转变为了在规定区间内,每次操作求[1,max+1]的和,并多次操作,直到无法操作为止。
那解法就很显然了,建立一棵权值线段树并可持久化,利用前缀和的方式维护区间查询。因为这道题涉及数值加减,所以不能直接离散化,而是要做一个结构体,把离散化后的值(所在叶子节点序号)和原数值都存一下。离散化后的值用于定位叶子节点,而线段树维护的是原数值的和。
复杂度\(N \log N + M \log P \log N\),其中\(P\)为\(\sum_{i=1}^{n} {a_i}\)。证明有点麻烦,过程可以看下面的链接。
代码
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
using namespace std;
#define maxn (int)(1e5+1000)
#define ll long long
int n,m,idx,len,rt[maxn<<5],a[maxn],sorted[maxn],ls[maxn<<5],rs[maxn<<5];
ll sum[maxn<<5];
int build(int l,int r){
int now=++idx;
if(l==r)return now;
int mid=(l+r)>>1;
ls[now]=build(l,mid);
rs[now]=build(mid+1,r);
return now;
}
int update(int last,int num,int l,int r){
int now=++idx;
sum[now]=sum[last]+sorted[num],ls[now]=ls[last],rs[now]=rs[last];
int mid=(l+r)>>1;
if(l==r){return now;}
if(num<=mid){
ls[now]=update(ls[last],num,l,mid);
}
else{
rs[now]=update(rs[last],num,mid+1,r);
}
return now;
}
int getid(int x){
return lower_bound(sorted+1,sorted+1+len,x)-sorted;
}
ll get_sum(int pre,int now,int l,int r,int tl,int tr){
if(tr<tl)return 0;
if(tl<=l&&r<=tr){return sum[now]-sum[pre];}
int mid=(l+r)>>1;
ll ans=0;
if(tl<=mid){ans+=get_sum(ls[pre],ls[now],l,mid,tl,tr);}
if(tr>=mid+1){ans+=get_sum(rs[pre],rs[now],mid+1,r,tl,tr);}
return ans;
}
int main(){
//freopen("in","r",stdin);
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);sorted[i]=a[i];
}
sort(sorted+1,sorted+1+n);
len=unique(sorted+1,sorted+1+n)-sorted-1;
rt[0]=build(1,n);
for(int i=1;i<=n;i++){
rt[i]=update(rt[i-1],getid(a[i]),1,len);
}
scanf("%d",&m);
for(int i=1;i<=m;i++){
int l,r;scanf("%d%d",&l,&r);
ll tot_max=0,sum_now=0;
while(1){
int id=upper_bound(sorted+1,sorted+1+len,tot_max+1)-sorted-1;
if(id>len)id=len;
sum_now=get_sum(rt[l-1],rt[r],1,len,1,id);
if(sum_now>tot_max){
tot_max=sum_now;
}
else if(sum_now==tot_max){
break;
}
}
tot_max++;
printf("%lld\n",tot_max);
}
return 0;
}