BZOJ1584: [Usaco2009 Mar]Cleaning Up 打扫卫生
n<=40000个<=m<=n的数,一段数不和谐(河蟹???)度为该段中不同数的个数的平方,求把n个数划成若干段后的最小不和谐度。
好题。首先可以确定是DP,f[i]=min(f[j]+P(j+1,i)),其中P(l,r)表示区间l到r的不同数的个数的平方。n2,过不了。
不过可以发现f是不下降的。证明略。
方法一:观察样例中的f数组,发现在i-1*1,i-2*2,……到i的区间中的数决定了f[i]这一步需不需要+1,于是开sqrt(n)个指针来扫这些区间。i-j*j到i的区间中的个数<j,f[i]就可以不加1。至于统计个数,开个200*40000的数组貌似是没问题的。
1 #include<stdio.h> 2 #include<string.h> 3 #include<algorithm> 4 #include<cstdlib> 5 #include<math.h> 6 //#include<iostream> 7 using namespace std; 8 9 int n,m; 10 #define maxn 40011 11 int cnt[205][maxn],a[maxn],pos[maxn],f[maxn]; 12 int main() 13 { 14 scanf("%d%d",&n,&m);m=(int)sqrt(n)+1; 15 for (int i=1;i<=n;i++) scanf("%d",&a[i]); 16 for (int i=1;i<=m;i++) pos[i]=-i*i; 17 f[0]=0; 18 memset(cnt,0,sizeof(cnt)); 19 for (int i=1;i<=n;i++) 20 { 21 for (int j=1;j<=m;j++) 22 { 23 if (pos[j]>0) 24 { 25 cnt[j][a[pos[j]]]--; 26 if (cnt[j][a[pos[j]]]==0) cnt[j][0]--; 27 } 28 pos[j]++; 29 cnt[j][a[i]]++; 30 if (cnt[j][a[i]]==1) cnt[j][0]++; 31 } 32 bool ok=0; 33 for (int j=1;j<=m;j++) 34 if (pos[j]>0 && cnt[j][0]<=j) ok=1; 35 f[i]=f[i-1]+(!ok); 36 } 37 printf("%d\n",f[n]); 38 return 0; 39 }
错误!未能理解f[i]不+1的原因,错误地把原因归咎为“a[i](第i位的数)可以和前面的某个状态合起来一起算”。我们找到的a[i]和前面某个状态合起来算的方案不一定就是最优解,有反例:
10 5
2 3 2 2 4 4 4 2 1 1
正确答案6,错误答案5。
方法二:最后答案不超过40000,所以这个不同个数的平方<=40000,不同个数<=200,也就是说,我们只要用“不同个数<=200的区间”来做P(l,r)。
说人话就是f[i]=min(f[pos[j]-1]+j*j),其中j<=sqrt(n),pos[j]表示满足pos[j]到i的区间内不同个数不超过j的最小下标。
如何维护pos[j]呢?首先遇到一个新的数a[i],我们需要知道a[i]是不是对于pos[j]到i-1的一个新的数,所以需记last[i]表示数值i的最后一个下标,用last[a[i]]和pos[j]比较就可以更新cnt[j],表示pos[j]到i中实际的不同个数。
然后如何调整pos[j]的左端呢?注意到cnt[j]>j是不合法的,这时候需要将pos[j]右移,直到出现某个last[a[pos[j]]]=pos[j],也就是把一个数灭绝了。这样均摊的时间在n*sqrt(n)范围内。然后再用pos[j]更新f[i]即可。
1 #include<stdio.h> 2 #include<string.h> 3 #include<algorithm> 4 #include<cstdlib> 5 #include<math.h> 6 //#include<iostream> 7 using namespace std; 8 9 int n,m; 10 #define maxn 40011 11 int pos[maxn],last[maxn],cnt[maxn],a[maxn],f[maxn]; 12 int main() 13 { 14 scanf("%d%d",&n,&m);m=(int)sqrt(n)+1; 15 for (int i=1;i<=n;i++) scanf("%d",&a[i]); 16 for (int i=1;i<=m;i++) pos[i]=1; 17 memset(last,0,sizeof(last)); 18 memset(cnt,0,sizeof(cnt)); 19 for (int i=1;i<=n;i++) 20 { 21 for (int j=1;j<=m;j++) if (last[a[i]]<pos[j]) cnt[j]++; 22 last[a[i]]=i; 23 for (int j=1;j<=m;j++) if (cnt[j]>j) 24 { 25 while (last[a[pos[j]]]!=pos[j]) pos[j]++; 26 pos[j]++; 27 cnt[j]--; 28 } 29 f[i]=n; 30 for (int j=1;j<=m;j++) f[i]=min(f[i],f[pos[j]-1]+j*j); 31 } 32 printf("%d\n",f[n]); 33 return 0; 34 }