[CF833B] The Bakery

Description

将一个长度为n的序列分为k段

使得总价值最大一段区间的价值表示为区间内不同数字的个数

\(n\leq 35000,k\leq 50,1\leq a_i\leq n\)

Solution

定义 \(dp[i][j]\) 表示前 i 个里面分 j 段的最大收益

一个显然的 dp 方程是 \(dp[i][j]=\max \limits_{1\leq p<i} dp[p][j-1]+w(p+1,i)\)。复杂度 \(O(n^2k)\),GG。

考虑优化此方程,因为是取 max,容易想到放在线段树上实现。

同时定义 \(pre[a[i]]\) 表示当前 \(a[i]\) 这个元素上一次出现的位置是哪里,如果没有出现则是 0 。

难点在于 \(w\) 数组如何动态快速的求出来,我们外层循环一个 \(j\) 表示分的段数,发现如果当前扫到 i 这个位置那么 a[i] 的贡献实际上是让 \([pre[a[i]],i]\) 这段区间整体加一。可以这么理解,就是当前扫到 i,那么对于所有到 i 截至的区间 \([p,i]\),a[i] 这个元素对这些区间有贡献的部分是左端点\(\in [pre[a[i]],i]\) 里的这一段。线段树区间加就好了。也就是说,当前扫到了 i ,那么线段树的叶子节点 p 表示的就是 \(w[p,i]\) 的值,这也是我们用线段树的意义所在。这样就可以 \(O(nlogn)\) 求出 w 数组了。同时 dp 数组实时更新即可。

还有一点要注意的是方程是 \(dp[p][j-1]+w(p+1,i)\) ,也就是说能用来更新答案的是 节点 p 的 dp 值和 p+1 的累加值,有点麻烦,干脆把所有的 dp 值都往左挪一个就行了,也就是叶子节点 p 表示的实际上是 p+1 的值。感觉有点绕。。。

Code

#include<cstdio>
#include<cctype>
#include<cstring>
#define K 55
#define N 35005
#define min(A,B) ((A)<(B)?(A):(B))
#define max(A,B) ((A)>(B)?(A):(B))
#define swap(A,B) ((A)^=(B)^=(A)^=(B))

int n,k;
int f[N];
int val[N];
int pre[N];
int mx[N<<2];
int lazy[N<<2];

int getint(){
	int x=0,f=0;char ch=getchar();
	while(!isdigit(ch)) f|=ch=='-',ch=getchar();
	while(isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
	return f?-x:x;
}

void build(int cur,int l,int r){
	if(l==r){
		mx[cur]=f[l-1];
		return;
	}
	int mid=l+r>>1;
	build(cur<<1,l,mid);
	build(cur<<1|1,mid+1,r);
	mx[cur]=max(mx[cur<<1],mx[cur<<1|1]);
}

void pushdown(int cur){
	if(!lazy[cur]) return;
	lazy[cur<<1]+=lazy[cur];
	lazy[cur<<1|1]+=lazy[cur];
	mx[cur<<1]+=lazy[cur];
	mx[cur<<1|1]+=lazy[cur];
	lazy[cur]=0;
}

void modify(int cur,int l,int r,int ql,int qr){
	if(!ql or !qr or ql>qr) return;
	if(ql<=l and r<=qr){
		mx[cur]++;
		lazy[cur]++;
		return;
	}
	pushdown(cur);
	int mid=l+r>>1;
	if(ql<=mid)
		modify(cur<<1,l,mid,ql,qr);
	if(mid<qr)
		modify(cur<<1|1,mid+1,r,ql,qr);
	mx[cur]=max(mx[cur<<1],mx[cur<<1|1]);
}

int query(int cur,int l,int r,int ql,int qr){
	if(ql<=l and r<=qr)
		return mx[cur];
	pushdown(cur);
	int mid=l+r>>1,ans=0;
	if(ql<=mid){
		int p=query(cur<<1,l,mid,ql,qr);
		ans=max(ans,p);
	}
	if(mid<qr){
		int p=query(cur<<1|1,mid+1,r,ql,qr);
		ans=max(ans,p);
	}
	return ans;
}

signed main(){
	n=getint(),k=getint();
	for(int i=1;i<=n;i++)
		val[i]=getint();
	for(int j=1;j<=k;j++){
		memset(mx,0,sizeof mx);
		memset(pre,0,sizeof pre);
		memset(lazy,0,sizeof lazy);
		build(1,1,n);
		for(int i=1;i<=n;i++){
			modify(1,1,n,pre[val[i]]+1,i);
			pre[val[i]]=i;
			//if(i<j) continue;
			f[i]=query(1,1,n,1,i);
			//printf("j=%d,i=%d,f=%d\n",j,i,f[i]);
		}
	}
	printf("%d\n",f[n]);
	return 0;
}
posted @ 2018-07-11 21:13  YoungNeal  阅读(202)  评论(0编辑  收藏  举报