bzoj1492 [NOI2007]货币兑换Cash
经典的1D1D动态规划题目,标准做法是平衡树维护凸壳,但实际上还有更简洁的分治法。
首先分析一下题目,对于任意一天,一定是贪心地买入所有货币或者卖出所有货币是最优的,因为有便宜我们就要尽量去占,有亏损就一点也不去碰。于是我们得到方程:
f[i]=max{f[j]/(a[j]*rate[j]+b[j])*rate[j]*a[i]+f[j]/(a[j]*rate[j]+b[j])*b[i]}
其中,x[j]=f[j]/(a[j]*rate[j]+b[j])*rate[j]表示第j天最多可以拥有的A货币的数量
y[j]=f[j]/(a[j]*rate[j]+b[j])表示第j天最多可以拥有的B货币的数量
那么方程可化简为f[i]=max{x[j]*a[i]+y[j]*b[i]},那么我们就是要选择一个最优的决策点(x[j],y[j])来更新f[i]得到最优解。
变形:y[j]=f[i]/b[i]-x[j]*a[i]/b[i],这是一个直线的斜截式方程,由于我们是用j去更新i,那么就相当于每次用一条斜率为-a[i]/b[i]的直线去切由若干(x[j],y[j])点组成的集合,能得到的最大截距的点,就是最优决策点,进一步,就是要维护一个由若干(x[j],y[j])点组成凸壳,因为最优决策点一定在凸壳上。
但是对于斜率-a[i]/b[i]和点(x[j],y[j])都是无序的,于是我们只能用一棵平衡树来维护凸壳,每次找到斜率能卡到的点(此点左侧的斜率和右侧的斜率恰好夹住-a[i]/b[i]斜率)。
具体splay实现:我们维护x坐标递增的点集,每次把新点插入到相应位置,更新凸壳的时候,分别找到新点左右能与它组成新的凸壳的点,把中间的点删掉;如果这个点完全在旧的凸壳内,那么把这个点删掉。每次找最优决策点的时候就拿-a[i]/b[i]去切凸壳就行了。
但这样搞实在是麻烦了许多,而且许多人得splay代码非常的长,在考场上就非常不容易写出来,于是出现了神一般的陈丹琪分治!
这个神级分治的精髓在于:变在线为离线,化无序为有序。
上面我们分析了,因为点和斜率都不是单调的,所以我们只能用一棵平衡树去维护。我们考虑导致无序的原因,是我们按照顺序依次回答了1..n的关于f值的询问。但是事实上我们并没有必要这么做,因为每个1..n的f[i]值,可能成为最优决策点一定在1..i范围内,而对于每个在1..i范围内的决策点,一定都有机会成为i+1..n的f值得最优决策点。这样1..i的f值一定不会受1..i的决策点的影响,i+1..n的点一定不会i+1..n的f值。于是可以分治!
对于一个分治过程solve(l,r),我们用l..mid的决策点去更新mid+1..r这部分的f值,这样递归地更新的话,我们一定可以保证在递归到i点的时候,1..i-1的点都已经更新过i点的f值了。我们看到,分治的过程中,左半区(l..mid)和右半区(mid+1..r)这两个区间的作用是不同的,我们要用左半区已经更新好的f值去求出点(x,y),然后用右半区的斜率去切左半区的点集更新f值。对于左半边我们需要的只是点(x,y),右半区我们需要的只是斜率-a[i]/b[i],两部分的顺序互不影响。于是,我们在处理好左半边的东西的时候保证点集按坐标排好序,在处理右半区之前保证询问按照斜率排好序,这样相当于用一系列连续变化的直线去切一些连续点组成的凸壳,那么我们就可以简单地用一个栈来维护连续点组成的凸壳,用扫描的方法更新f值。我们一开始就排好询问的顺序,然后保证在solve之前还原左半区询问集合的顺序,这样就保证了按照原顺序得到f值;在solve之后把两部分点集归并,这样就保证了每个过程中的点集是有序的。
虽然我叙述的比较烦,但是我们看到这个分治的过程是非常优美的,对于询问我们是先排序,后还原;而对于点集我们是不断地归并,恰好是对称的过程。为什么呢?上面已经说了,因为我们对左右两个半区的需求是不一样的,于是这样就得到了两个不同的有序序列,把无序化为有序。
这样看来,分治算法取代一些复杂的数据结构是一种强有力的趋势。
splay代码:
1 #include<iostream> 2 #include<cstdio> 3 #include<algorithm> 4 #include<cmath> 5 #include<cstring> 6 #define maxn 120000 7 #define eps 1e-9 8 #define inf 1e9 9 using namespace std; 10 int fa[maxn],c[maxn][2]; 11 double f[maxn],x[maxn],y[maxn],lk[maxn],rk[maxn],a[maxn],b[maxn],rate[maxn]; 12 int n,m,rot,num; 13 14 inline double fabs(double x) 15 { 16 return (x>0)?x:-x; 17 } 18 19 inline void zigzag(int x,int &rot) 20 { 21 int y=fa[x],z=fa[y]; 22 int p=(c[y][1]==x),q=p^1; 23 if (y==rot) rot=x; 24 else if (c[z][0]==y) c[z][0]=x; else c[z][1]=x; 25 fa[x]=z; fa[y]=x; fa[c[x][q]]=y; 26 c[y][p]=c[x][q]; c[x][q]=y; 27 } 28 29 inline void splay(int x,int &rot) 30 { 31 while (x!=rot) 32 { 33 int y=fa[x],z=fa[y]; 34 if (y!=rot) 35 if ((c[y][0]==x)xor(c[z][0]==y)) zigzag(x,rot); else zigzag(y,rot); 36 zigzag(x,rot); 37 } 38 } 39 40 inline void insert(int &t,int anc,int now)//加入平衡树 41 { 42 if (t==0) 43 { 44 t=now; 45 fa[t]=anc; 46 return ; 47 } 48 if (x[now]<=x[t]+eps) insert(c[t][0],t,now); 49 else insert(c[t][1],t,now); 50 } 51 52 inline double getk(int i,int j)//求斜率 53 { 54 if (fabs(x[i]-x[j])<eps) return -inf; 55 else return (y[j]-y[i])/(x[j]-x[i]); 56 } 57 58 inline int prev(int rot)//求可以和当前点组成凸包的右边第一个点 59 { 60 int t=c[rot][0],tmp=t; 61 while (t) 62 { 63 if (getk(t,rot)<=lk[t]+eps) tmp=t,t=c[t][1]; 64 else t=c[t][0]; 65 } 66 return tmp; 67 } 68 inline int succ(int rot)//求可以和当前点组成凸包的左边第一个点 69 { 70 int t=c[rot][1],tmp=t; 71 while (t) 72 { 73 if (getk(rot,t)+eps>=rk[t]) tmp=t,t=c[t][0]; 74 else t=c[t][1]; 75 } 76 return tmp; 77 } 78 79 inline void update(int t)//加入t点 80 { 81 splay(t,rot); 82 if (c[t][0])//向左求凸包 83 { 84 int left=prev(rot); 85 splay(left,c[rot][0]); c[left][1]=0; 86 lk[t]=rk[left]=getk(left,t); 87 } 88 else lk[t]=inf; 89 if (c[t][1])//向右求凸包 90 { 91 int right=succ(rot); 92 splay(right,c[rot][1]); c[right][0]=0; 93 rk[t]=lk[right]=getk(t,right); 94 } 95 else rk[t]=-inf; 96 if (lk[t]<=rk[t]+eps)//在原凸包内部的情况,直接删掉该点 97 { 98 rot=c[t][0]; c[rot][1]=c[t][1]; fa[c[t][1]]=rot; fa[rot]=0; 99 lk[rot]=rk[c[t][1]]=getk(rot,c[t][1]); 100 } 101 } 102 103 inline int find(int t,double k)//找到当前斜率的位置,即找到最优值 104 { 105 if (t==0) return 0; 106 if (lk[t]+eps>=k&&k+eps>=rk[t]) return t; 107 if (k+eps>lk[t]) return find(c[t][0],k); 108 else return find(c[t][1],k); 109 } 110 111 int main() 112 { 113 //freopen("cash.in","r",stdin); 114 //freopen("cash.out","w",stdout); 115 scanf("%d%lf",&n,&f[0]); 116 for (int i=1;i<=n;i++) scanf("%lf%lf%lf",&a[i],&b[i],&rate[i]); 117 for (int i=1;i<=n;i++) 118 { 119 int j=find(rot,-a[i]/b[i]); 120 f[i]=max(f[i-1],x[j]*a[i]+y[j]*b[i]); 121 y[i]=f[i]/(a[i]*rate[i]+b[i]); 122 x[i]=y[i]*rate[i]; 123 insert(rot,0,i); 124 update(i); 125 } 126 printf("%.3lf\n",f[n]); 127 return 0; 128 }
分治代码:
1 #include<iostream> 2 #include<cstdio> 3 #include<algorithm> 4 #include<cmath> 5 #include<cstring> 6 #define maxn 120000 7 #define eps 1e-9 8 #define inf 1e9 9 using namespace std; 10 struct query 11 { 12 double q,a,b,rate,k; 13 int pos; 14 }q[maxn],nq[maxn]; 15 double fabs(double x) 16 { 17 return (x>0)?x:-x; 18 } 19 struct point 20 { 21 double x,y; 22 friend bool operator <(const point &a,const point &b) 23 { 24 return (a.x<b.x+eps)||(fabs(a.x-b.x)<=eps&&a.y<b.y+eps); 25 } 26 }p[maxn],np[maxn]; 27 int st[maxn]; 28 double f[maxn]; 29 int n,m; 30 31 double getk(int i,int j) 32 { 33 if (i==0) return -inf; 34 if (j==0) return inf; 35 if (fabs(p[i].x-p[j].x)<=eps) return -inf; 36 return (p[i].y-p[j].y)/(p[i].x-p[j].x); 37 } 38 39 void solve(int l,int r) 40 { 41 if (l==r)//此时l之前包括l的f值已经达到最优,计算出对应的点即可 42 { 43 f[l]=max(f[l-1],f[l]); 44 p[l].y=f[l]/(q[l].a*q[l].rate+q[l].b); 45 p[l].x=p[l].y*q[l].rate; 46 return ; 47 } 48 int mid=(l+r)>>1,l1=l,l2=mid+1; 49 //对询问集合排序,1位置2斜率 50 for (int i=l;i<=r;i++) 51 if (q[i].pos<=mid) nq[l1++]=q[i]; 52 else nq[l2++]=q[i]; 53 for (int i=l;i<=r;i++) q[i]=nq[i]; 54 //递归左区间 55 solve(l,mid); 56 //左半区所有点都以计算好,把它们入栈,维护凸壳 57 int top=0; 58 for (int i=l;i<=mid;i++) 59 { 60 while (top>=2&&getk(i,st[top])+eps>getk(st[top],st[top-1])) top--; 61 st[++top]=i; 62 } 63 //拿左半区更新右半区 64 int j=1; 65 for (int i=r;i>=mid+1;i--)//保证询问斜率递减 66 { 67 while (j<top&&q[i].k<getk(st[j],st[j+1])+eps) j++; 68 f[q[i].pos]=max(f[q[i].pos],p[st[j]].x*q[i].a+p[st[j]].y*q[i].b); 69 } 70 //递归右区间 71 solve(mid+1,r); 72 //合并左右区间的点,按照x,y排序 73 l1=l,l2=mid+1; 74 for (int i=l;i<=r;i++) 75 if ((p[l1]<p[l2]||l2>r)&&l1<=mid) np[i]=p[l1++]; 76 else np[i]=p[l2++]; 77 for (int i=l;i<=r;i++) p[i]=np[i]; 78 } 79 80 bool cmp(query a,query b) 81 { 82 return a.k<b.k; 83 } 84 85 int main() 86 { 87 //freopen("cash.in","r",stdin); 88 //freopen("cash.out","w",stdout); 89 scanf("%d%lf",&n,&f[0]); 90 for (int i=1;i<=n;i++) 91 { 92 scanf("%lf%lf%lf",&q[i].a,&q[i].b,&q[i].rate); 93 q[i].k=-q[i].a/q[i].b; 94 q[i].pos=i; 95 } 96 sort(q+1,q+n+1,cmp); 97 solve(1,n); 98 printf("%.3lf\n",f[n]); 99 return 0; 100 }