[Luogu4331][Baltic2004]数字序列
原题请见《左偏树的特点及其应用》BY 广东省中山市第一中学 黄源河
luogu
题意
给出序列\(a[1...n]\),要求构造一个不下降序列\(b[1...n]\)使得\(\sum_{i=1}^{n}|a_i-b_i|\)最小。
sol
首先很自然地能够想到,构造出来的序列\(b[1...n]\)一定可以划分成\(m\)段\((1\le{m}\le{n})\),每段内数字全部相同。
我们把每一段的数字提取出来分别为\(c[1...m]\)。
如果对每一段的\(c[i]\)都取最优的话,那么一定是去这一段中\(a[i]\)的中位数。
但是取中位数可能会导致序列\(c\)不满足非降,这个时候就需要把相邻的两个不合法的段合并成一段。
所以就需要维护中位数。
左偏树。对于一个长度为\(x\)的段,左偏树中保存这一段中前\(\lfloor\frac{x+1}{2}\rfloor\)小的数字,易知这些数里面最大的那个就是中位数,合并的时候直接合并两棵左偏树。
因为\(\lfloor\frac{x+1}{2}\rfloor+\lfloor\frac{y+1}{2}\rfloor=\lfloor\frac{x+y+1}{2}\rfloor-1\)当且仅当\(x,y\)均为奇数,所以这种情况要弹掉堆顶元素。
复杂度\(O(n\log{n})\)
注:洛谷的题目是要求构造一个递增序列,可以采用减下标的方法,即输入时把每个数都减去对应下表,输出时再加上,这样就可以完成不下降序列和递增序列的转换。
code
#include<cstdio>
#include<algorithm>
using namespace std;
#define ll long long
inline int gi()
{
int x=0,w=1;char ch=getchar();
while ((ch<'0'||ch>'9')&&ch!='-') ch=getchar();
if (ch=='-') w=0,ch=getchar();
while (ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return w?x:-x;
}
const int N = 1e6+5;
int n,key[N],ls[N],rs[N],dis[N],S[N],ed[N],top;
ll ans;
int merge(int A,int B)
{
if (!A||!B) return A|B;
if (key[A]<key[B]) swap(A,B);
rs[A]=merge(rs[A],B);
if (dis[ls[A]]<dis[rs[A]]) swap(ls[A],rs[A]);
dis[A]=dis[rs[A]]+1;
return A;
}
int main()
{
n=gi();
for (int i=1;i<=n;++i) key[i]=gi();
for (int i=1;i<=n;++i)
{
++top;S[top]=i;ed[top]=i;
while (top>1&&key[S[top]]<key[S[top-1]])
{
--top;
S[top]=merge(S[top],S[top+1]);
if (((ed[top+1]-ed[top])&1)&&((ed[top]-ed[top-1])&1))
S[top]=merge(ls[S[top]],rs[S[top]]);
ed[top]=ed[top+1];
}
}
for (int i=1;i<=top;++i)
for (int j=ed[i-1]+1;j<=ed[i];++j)
ans+=abs(key[j]-key[S[i]]);
printf("%lld\n",ans);
return 0;
}
强行再贴一个洛谷上那道题的代码
#include<cstdio>
#include<algorithm>
using namespace std;
#define ll long long
inline int gi()
{
int x=0,w=1;char ch=getchar();
while ((ch<'0'||ch>'9')&&ch!='-') ch=getchar();
if (ch=='-') w=0,ch=getchar();
while (ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return w?x:-x;
}
const int N = 1e6+5;
int n,key[N],ls[N],rs[N],dis[N],S[N],ed[N],top;
ll ans;
int merge(int A,int B)
{
if (!A||!B) return A|B;
if (key[A]<key[B]) swap(A,B);
rs[A]=merge(rs[A],B);
if (dis[ls[A]]<dis[rs[A]]) swap(ls[A],rs[A]);
dis[A]=dis[rs[A]]+1;
return A;
}
int main()
{
n=gi();
for (int i=1;i<=n;++i) key[i]=gi()-i;
for (int i=1;i<=n;++i)
{
++top;S[top]=i;ed[top]=i;
while (top>1&&key[S[top]]<key[S[top-1]])
{
--top;
S[top]=merge(S[top],S[top+1]);
if (((ed[top+1]-ed[top])&1)&&((ed[top]-ed[top-1])&1))
S[top]=merge(ls[S[top]],rs[S[top]]);
ed[top]=ed[top+1];
}
}
for (int i=1;i<=top;++i)
for (int j=ed[i-1]+1;j<=ed[i];++j)
ans+=abs(key[j]-key[S[i]]);
printf("%lld\n",ans);
for (int i=1;i<=top;++i)
for (int j=ed[i-1]+1;j<=ed[i];++j)
printf("%d ",key[S[i]]+j);
puts("");return 0;
}