bzoj 4345: [POI2016]Korale
Description
有n个带标号的珠子,第i个珠子的价值为a[i]。现在你可以选择若干个珠子组成项链(也可以一个都不选),项链的价值为所有珠子的价值和。现在给所有可能的项链排序,先按权值从小到大排序,对于权值相同的,根据所用珠子集合的标号的字典序从小到大排序。请输出第k小的项链的价值,以及所用的珠子集合。
Input
第一行包含两个正整数n,k(1<=n<=1000000,1<=k<=min(2^n,1000000))。
第二行包含n个正整数,依次表示每个珠子的价值a[i](1<=a[i]<=10^9)。
Output
第一行输出第k小的项链的价值。
第二行按标号从小到大依次输出该项链里每个珠子的标号。
Sample Input
4 10
3 7 4 3
3 7 4 3
Sample Output
10
1 3 4
1 3 4
HINT
Source
用堆来求k优解是一个很常用的方法了,我们先排序,堆中存入二元组(sum,i),表示和为sum,最大的元素的编号为i,
那么每次取出(sum,i),把(sum+a[i+1],i+1)和(sum-a[i]+a[i+1],i+1)丢入堆中即可;
然后我们考虑如何求出字典序,考虑用dfs来实现,假设dfs传的参为(x,sum),那么我们每次都是从(x+1,n)中最小的满足a[i]<=sum的i开始搜索,这样就不用枚举x+1-n了;
这样满足dfs求字典序的搜索顺序;因为我们只会搜索到k个,所以复杂度是对的;上面那个问题我们可以在线段树上进行查询;
线段树上维护区间最小值,然后在线段树上二分即可;
//MADE BY QT666 #include<iostream> #include<cstdio> #include<algorithm> #include<cstring> #include<queue> #define lson x<<1 #define rson x<<1|1 using namespace std; typedef long long ll; const int N=1000050; int n,k; ll ans[N],a[N],b[N],tr[N*4],Mn[N*4]; struct data{ ll sum;int j; }; bool operator < (const data &a,const data &b){ return a.sum>b.sum; } priority_queue<data> Q; int zhan[N],tot,tt,K; void build(int x,int l,int r){ if(l==r) {tr[x]=l,Mn[x]=b[l];return;} int mid=(l+r)>>1; build(lson,l,mid);build(rson,mid+1,r); if(tr[lson]) tr[x]=tr[lson]; if(tr[rson]) tr[x]=min(tr[x],tr[rson]); Mn[x]=min(Mn[lson],Mn[rson]); } int query(int x,int l,int r,int xl,int xr,ll v){ if(l==r){ if(Mn[x]<=v) return l; else return n+1; } if(xl<=l&&r<=xr){ int mid=(l+r)>>1; if(Mn[x]>v) return n+1; else if(Mn[lson]<=v) return query(lson,l,mid,xl,mid,v); else return query(rson,mid+1,r,mid+1,xr,v); } int mid=(l+r)>>1; if(xr<=mid) return query(lson,l,mid,xl,xr,v); else if(xl>mid) return query(rson,mid+1,r,xl,xr,v); else return min(query(lson,l,mid,xl,mid,v),query(rson,mid+1,r,mid+1,xr,v)); } void dfs(int x,ll sum){ if(K>=tt) return; if(!sum){ K++; if(K==tt) for(int i=1;i<=tot;i++) printf("%d ",zhan[i]); return; } if(x==n) return; for(int i=x+1;i<=n;i++){ i=query(1,1,n,i,n,sum); if(i<=n){ zhan[++tot]=i;dfs(i,sum-b[i]);tot--; } else return; } } int main(){ scanf("%d%d",&n,&k); for(int i=1;i<=n;i++) scanf("%lld",&a[i]),b[i]=a[i]; sort(a+1,a+1+n);Q.push((data){a[1],1});k--; for(int i=1;i<=k;i++){ data x=Q.top();Q.pop();ans[i]=x.sum; if(x.j+1<=n) Q.push((data){x.sum+a[x.j+1],x.j+1}); if(x.j+1<=n) Q.push((data){x.sum-a[x.j]+a[x.j+1],x.j+1}); } while(ans[k]==ans[k-(tt+1)+1]) tt++; printf("%lld\n",ans[k]);build(1,1,n); dfs(0,ans[k]); return 0; }