HZOJ Function
比较神仙的一道dp,考试的时候还以为是打表找规律啥的。
我们重新描述一下这道题:一个10 9 × n的网格,每个格子有一个权值,每一列格子的权值都是相同的。从一个起点开始,每次可以向上走一格或者向左上角走一格,直到走到最上面一行为止,你需要最小化经过的格子的总权值。
然而我并没有看出来。
首先我们可以发现一些显然的性质,最优的路径之一一定形如:先往左上走若干步(可能不走),到达权值较小的一列后,一直往上走到顶。对于每个询问,枚举从起点出发最终会到达哪一列,就可以得到一个O(nq)的做法。
然而我也没有想到……
算了直接把作者的题解全放出来吧:
对于任意1 ≤ i ≤ j ≤ n, 从(x, j)出发最终到达第i列然后走到顶的代价,可以表示为一个关于x的一次函数,我们只关心这些一次函数的最小值,也就是这些直线形成的下凸壳。我们得到一个思路:将询问离线,按y从小到大排序,从最左边开始每次加入一条直线,维护下凸壳,然后在凸壳上二分即可得到答案。怎么维护下凸壳呢?对于一个点(x, y),它要么继承上一列x − 1的决策,要么就直接往上走到顶。并且我们发现,第二种情况只会出现在从顶端开始连续的一段中。于是我们只需要用栈维护凸壳即可。O((n + q) log n).
刚开始没怎么看懂,好像我的做法和题解也不是很一样,其实现在还有一些细节没有搞明白……
首先看暴力的式子:$ans=min(ans,sum[y(i)]-sum[j]+(x(i)-y(i)+j)*A[j])$
在y固定时,他是一个关于x的一次函数,即$y=kx+b$的形式,设走到j时停止然后向上走,那么$k=A[j],b=sum[y]-sum[j]+(j-y)*A[j]$
对于每一个j都是一条直线,那么这些直线构成了一个上凸壳。
我们可以用On的复杂度枚举y,用栈维护凸壳(添加直线是加在了坐标系的最左边),考虑y增加会给直线造成什么影响,只会使直线的截距发生改变而斜率不变,所以原来的凸壳仍然是对的。
那么考虑如何吧j=y的这条之间加入凸壳,首先将斜率大于这条直线的栈顶直线弹掉,然后交点也得是单调的,继续弹掉不合法的,(自己yy一下坐标系,横轴是询问的x,纵轴为最优解),
然后处理当前y的询问,直接二分栈找到当前x在坐标系中对应的直线就可以了(一定注意栈顶其实是坐标轴最左边的直线)。
放下代码(稍恶心):
1 #include<algorithm> 2 #include<iostream> 3 #include<cstring> 4 #include<cstdio> 5 #define st sta[top] 6 #define sm sta[mid] 7 #define sm1 sta[mid+1] 8 #define st1 sta[top-1] 9 #define int LL 10 #define LL long long 11 using namespace std; 12 struct ques 13 { 14 int x,y,id; 15 #define x(i) que[i].x 16 #define y(i) que[i].y 17 #define id(i) que[i].id 18 friend bool operator < (ques a,ques b) 19 {return a.y<b.y;} 20 }que[500010]; 21 int n,A[500010],q,maxx; 22 LL sum[500010],al[500010]; 23 LL sta[500010],top; 24 double getx(int k1,int k2,int j1,int j2){return (double)(j2-j1)/(double)(k1-k2);} 25 inline int read(); 26 signed main() 27 { 28 // freopen("function2.in","r",stdin); 29 // freopen("out.out","w",stdout); 30 31 n=read(); 32 for(int i=1;i<=n;i++)A[i]=read(),sum[i]=sum[i-1]+A[i]; 33 q=read(); 34 for(int i=1;i<=q;i++)x(i)=read(),y(i)=read(),id(i)=i; 35 sort(que+1,que+q+1); 36 37 int now=1; 38 for(int y=1;y<=n;y++) 39 { 40 while(top&&A[sta[top]]>=A[y])top--; 41 while(top>1&& 42 getx(A[y],A[st],0,sum[y]-sum[st]+A[st]*(st-y)) 43 >=getx(A[st1],A[st],sum[y]-sum[st1]+A[st1]*(st1-y),sum[y]-sum[st]+A[st]*(st-y)) 44 )top--; 45 sta[++top]=y; 46 for(;y(now)==y&&now<=q;now++) 47 { 48 int l=1,r=top,mid; 49 while(l<r) 50 { 51 mid=(l+r)>>1; 52 double tx=getx(A[sm],A[sm1],sum[y]-sum[sm]+A[sm]*(sm-y),sum[y]-sum[sm1]+A[sm1]*(sm1-y)); 53 if(x(now)<=tx)l=mid+1; 54 else r=mid; 55 } 56 mid=l; 57 al[id(now)]=sum[y]-sum[sm]+A[sm]*(x(now)-y+sm); 58 } 59 60 } 61 for(int i=1;i<=q;i++)printf("%lld\n",al[i]); 62 } 63 inline int read() 64 { 65 int s=0,f=1;char a=getchar(); 66 while(a<'0'||a>'9'){if(a=='-')f=-1;a=getchar();} 67 while(a>='0'&&a<='9'){s=s*10+a-'0';a=getchar();} 68 return s*f; 69 }