【STSRM12】夏令营
【题意】n个数划分成k段,每段的价值为段内不同数字的数量,求最大总价值
【算法】DP+线段树
【题解】
f[i][j]表示前i个数字划分成j段的最大价值。
f[i][j]=max(f[k][j-1]+value(k+1,j)),j-1<=k<i。
暴力复杂度O(n^3*k),预处理value后复杂度降为O(n^2*k)。
正解考虑加入一个数字i,只能为k+1~i贡献1的价值,其中k为数字i上一次出现的位置。
那么排序预处理上一次出现位置,区间+1用线段树维护,取max用线段树查询,复杂度O(nk*log n)。
注意线段树常数……特别注意手写max。
从头查比从中间查常数小。
#include<cstdio> #include<algorithm> #include<cstring> using namespace std; const int maxn=36010; struct tree{int l,r,ms,delta;}t[maxn*4]; struct cyc{int num,ord;}b[maxn]; bool cmp(cyc a,cyc b){return a.num<b.num||(a.num==b.num&&a.ord<b.ord);} int f[maxn],n,kind,a[maxn],c[maxn]; void pushup(int k){t[k].ms=max(t[k<<1].ms,t[k<<1|1].ms);} void pushdown(int k){ if(t[k].delta){ t[k<<1].delta+=t[k].delta; t[k<<1].ms+=t[k].delta; t[k<<1|1].delta+=t[k].delta; t[k<<1|1].ms+=t[k].delta; t[k].delta=0; } } void build(int k,int l,int r){ t[k].l=l;t[k].r=r; if(l==r){t[k].ms=f[l];t[k].delta=0;} else{ int mid=(l+r)>>1; build(k<<1,l,mid); build(k<<1|1,mid+1,r); pushup(k); t[k].delta=0; } } void insert(int k,int l,int r,int x){ if(l<=t[k].l&&t[k].r<=r){t[k].delta+=x;t[k].ms+=x;} else{ pushdown(k); int mid=(t[k].l+t[k].r)>>1; if(l<=mid)insert(k<<1,l,r,x); if(r>mid)insert(k<<1|1,l,r,x); pushup(k); } } int query(int k,int l,int r){ if(l<=t[k].l&&t[k].r<=r)return t[k].ms; else{ pushdown(k); int mid=(t[k].l+t[k].r)>>1; int ans=0; if(l<=mid)ans=query(k<<1,l,r); if(r>mid)ans=max(ans,query(k<<1|1,l,r)); return ans; } } int main(){ scanf("%d%d",&n,&kind); for(int i=1;i<=n;i++){scanf("%d",&a[i]);b[i].num=a[i];b[i].ord=i;} sort(b+1,b+n+1,cmp); for(int i=1;i<=n;i++)if(b[i].num==b[i-1].num)c[b[i].ord]=b[i-1].ord; build(1,0,n); for(int j=1;j<=kind;j++){ for(int i=1;i<=n;i++){ insert(1,c[i],i-1,1); f[i]=query(1,0,i-1); } build(1,0,n); } printf("%d",f[n]); return 0; }
暴力AC法:打表容易发现具有决策单调性,原因在于加入i时已知i-1的决策点j比左边优,加入i后对i上一次出现的位置到i有贡献,如果左边有贡献则j一定有贡献,所以决策点k>=j。
决策单调性可以用分支决策维护,复杂度O(n^2*k*log n)。
瓶颈在快速求区间数字个数,用主席树维护,复杂度O(nk*log n)。
#include<cstdio> #include<cstring> #include<cctype> #include<cmath> #include<algorithm> #define ll long long using namespace std; int read() { char c;int s=0,t=1; while(!isdigit(c=getchar()))if(c=='-')t=-1; do{s=s*10+c-'0';}while(isdigit(c=getchar())); return s*t; } /*------------------------------------------------------------*/ const int inf=0x3f3f3f3f,maxn=40010; int n,kind,f[2][maxn],a[maxn],x,dfsnum=0,b[maxn]; bool c[maxn]; int work(int x,int y){ dfsnum++; int ans=0; for(int i=x;i<=y;i++)if(b[a[i]]!=dfsnum){ b[a[i]]=dfsnum; ans++; } return ans; } void find(int l,int r,int L,int R) { if(l>r||L>R)return; int mid=(l+r)>>1; int w,v=-inf+1; for(int i=L;i<=R&&i<mid;i++){ if(f[1-x][i]+work(i+1,mid)>v){ v=f[1-x][i]+work(i+1,mid); w=i; } } f[x][mid]=v; find(l,mid-1,L,w); find(mid+1,r,w,R); } int main() { scanf("%d%d",&n,&kind); for(int i=1;i<=n;i++)scanf("%d",&a[i]); x=1; for(int j=1;j<=kind;j++){ x=1-x; f[x][0]=-inf; find(1,n,0,n-1); } printf("%d",f[x][n]); return 0; }