[原题描述以及提交地址]:http://acm.tongji.edu.cn/problem?pid=10011
[题目大意]
给定两个长度为N的序列,要给这两个序列的数连线。连线只能在两个序列之间进行,且连线不能交叉,每个数最多只能选一次。连线从左到右进行,每次连线收益为这两个数的乘积。对于两个序列,都有:每段连续的没被选中的数的和的平方为损失。
防剧透
防剧透
防剧透
防剧透
防剧透
防剧透
防剧透
防剧透
防剧透
防剧透
防剧透
防剧透
防剧透
[解题思路]
O(n^4):
f[i][j]代表a序列前i个数,b序列前j个数中i,j必选所得到的最优收益。
f[i][j] = a[i] * b[j] + max(f[k][l] - (suma[i - 1] - suma[k])^2 - (sumb[j - 1] - sumb[l])^2) {0 < k < i, 0 < l < j}
===================================================================================
O(n^3):
可以发现对于k + 1...i 以及 l + 1...j 这两段数之间可以再连线,而且答案不会更劣。
于是有k == i - 1 or l == j - 1
f[i][j] = a[i] * b[j] + max(f[k][j - 1] - (suma[i - 1] - suma[k])^2,f[i - 1][l] - (sumb[j - 1] - sumb[l])^2) {0 < k < i, 0 < l < j}
===================================================================================
O(n^2):
事实上以上的方程是可以用斜率优化的。只不过是同时依赖于两个斜率优化方程而已。于是,对于每个i,j开一个单调队列,维护即可。
===================================================================================
Postscript:打斜率优化的时候一定要注意等号,而且最好从凸包的角度来理解,来实现,比较不容易出错。
#include <cstdio> #include <algorithm> #include <deque> const int N = 1000 + 9; typedef long long ll; int n,a[N],b[N],i,j,t; ll suma[N],sumb[N],f[N][N]; std::deque<int> qi[N],qj[N]; inline ll sqr(const ll x){return x*x;} inline ll calci(const int x) {return f[i - 1][x] - sqr(sumb[j - 1] - sumb[x]);} inline ll calcj(const int x) {return f[x][j - 1] - sqr(suma[i - 1] - suma[x]);} inline ll Xi(const int k,const int l) {return f[i - 1][k] - sqr(sumb[k]) - (f[i - 1][l] - sqr(sumb[l]));} inline ll Yi(const int k,const int l) {return sumb[l] - sumb[k];} inline ll Xj(const int k,const int l) {return f[k][j - 1] - sqr(suma[k]) - (f[l][j - 1] - sqr(suma[l]));} inline ll Yj(const int k,const int l) {return suma[l] - suma[k];} int main() { #ifndef ONLINE_JUDGE freopen("sxbk.in","r",stdin); freopen("sxbk.out","w",stdout); #endif scanf("%d",&n); for (i = 1; i <= n; ++i) { scanf("%d",a+i); suma[i] = suma[i - 1] + a[i]; } for (i = 1; i <= n; ++i) { scanf("%d",b+i); sumb[i] = sumb[i - 1] + b[i]; } for (i = 1; i <= n; ++i) { for (j = 1; j <= n; ++j) { while (qi[i - 1].size() > 1 && calci(qi[i - 1].front()) <= calci(qi[i - 1][1])) qi[i - 1].pop_front(); while (qj[j - 1].size() > 1 && calcj(qj[j - 1].front()) <= calcj(qj[j - 1][1])) qj[j - 1].pop_front(); f[i][j] = - sqr(suma[i - 1]) - sqr(sumb[j - 1]); if ((i - 1) && qi[i - 1].size()) f[i][j] = std::max(f[i][j],calci(qi[i - 1].front())); if ((j - 1) && qj[j - 1].size()) f[i][j] = std::max(f[i][j],calcj(qj[j - 1].front())); f[i][j] += a[i] * b[j]; while ((t = qi[i - 1].size()) > 1 && Xi(qi[i - 1][t - 2],qi[i - 1].back()) * Yi(qi[i - 1].back(),j) >= Xi(qi[i - 1].back(),j) * Yi(qi[i - 1][t - 2],qi[i - 1].back())) qi[i - 1].pop_back(); while ((t = qj[j - 1].size()) > 1 && Xj(qj[j - 1][t - 2],qj[j - 1].back()) * Yj(qj[j - 1].back(),i) >= Xj(qj[j - 1].back(),i) * Yj(qj[j - 1][t - 2],qj[j - 1].back())) qj[j - 1].pop_back(); if (i - 1) qi[i - 1].push_back(j); if (j - 1) qj[j - 1].push_back(i); } } ll ans = -0x7fffffff; for (int i = 1; i <= n; ++i) ans = std::max(ans,std::max(f[i][n] - sqr(suma[n] - suma[i]),f[n][i] - sqr(sumb[n] - sumb[i]))); printf("%I64d\n",ans); }