P4983 忘情 题解

题目链接:忘情

大概是本题为数不多的李超树解法,凸包显然太经典,不再赘述。

有道差不多的弱化版题:Atcoder Educational DP Contest 题解\(z\) 题,也就是最后一题,差不多的区间 dp 使用凸优化斜率 dp,不过那题我也是李超树写的,比较喜欢李超树。凸优化参照它人文章就行了,具体的其实就是求一个凸包,Andrew 算法 比较常用,就单调栈里维护凸包,每次取栈顶的两个点,我们知道单调栈存的是最近的两个点,然后和当前点组成三个点,两个矢量,直接根据叉乘结果决定是否满足上下凸包,不满足弹栈即可,每次再将当前点加入单调栈中,最后单调栈中保存的就是凸包了,并且也可以边求边算 dp,求完,dp 也跟着算完了。


先考虑简化原式,看看一个区间段的 \(val\) 这一大坨简化后是啥:

\[\dfrac{((\sum_{i=1}^{n}x_i\times \overline{x})+\overline{x})^2}{\overline{x}^2},观察到分母有个\ \overline{x}^2,我们分子括号内提一个公因子\ \overline{x} \]

\[分母=\overline{x}^2\times (\sum_{i=1}^{n}x_i+1)^2,约分得到原式=(\sum_{i=1}^{n}x_i +1)^2 \]

\[老规矩,区间和改用前缀和转化为变量少的形式,原式=(pre[r]-pre[l]+1)^2 \]

不考虑 \(m\) 限制,又是经典的分段最优,我们写出它的区间 dp 方程:

\[dp[r]=\min{dp[l]+(pre[r]-pre[l]+1)^2} \ (l即为上一段终点,所以不是\ l-1) \]

\[观察\ dp[l]+(pre[r]-pre[l]+1)^2\ 这个式子特点 \]

这个是个凸包,感性证明下,\(pre[l]\) 是单调递增函数,所以括号内的是个单减函数,而 \(dp[l]\) 表示前 \(l\) 个数的最优答案,随着数的增加,dp 是单调不降的,因为原数都是正数。
所以每增加一个数,至少可以让原来最优方案的最后一段的 \((pre[r]-pre[l]+1)^2\) 增大,所以 dp 一定是单调不降函数。注意到 dp 方程本身相当于由后者反复累计的,其实可以看做一个后者二元函数和,减去一个后者,后者这个二元函数图像固定一个,其实就是一个抛物线了。要么保持保持单调,要么呈现凸包形式,但显然我们将单调函数视为不完全凸包,或者视为凸包的一部分,显然这个转移是一个显而易见的凸包优化 dp 问题。但注意到题目还有个限制 恰好 \(m\) 段限制。这个问题就犯难了,我们不会还得加一个维度分割为 \(j\) 段最优的维度吧。

关于 wqs 二分,这是一个经典的 wqs 二分模型,一般来说看到分为 恰好 \(m\) 段这个字眼加上没有这个限制的条件下是一个凸包问题就很明显了。一般这种带区间段限制选择的最好用的就是 wqs 二分,它的大前提,若干个物品中恰好选择 \(m\) 个中取最优,且转移方程为凸包问题。

具体的二分斜率 \(k\),把每个段当做物品,它的转移价值即为 \((pre[r]-pre[l]+1)^2-k\),而 \(g(x)\) 表示取 \(x\) 个数的最优,显然最终 \(g(m)=dp[n]\),且此时的 \(dp[n]\) 恰好取得 \(m\) 段,答案即为 \(g(m)-k\times m=\) 即为 \(dp[n]-k\times m\),这是一个经典的套路,建议反复学会。

最后我们来详细说说下怎么写,首先 dp 方程当然不能直接这样写了,写一下转移,这里我稍微提一点,对于 wqs 二分,常常的我们二分的 check 主要是关于数量限制的,但此时常常可以同时算出最优答案了,所以嘛,我们 dp 方程同时维护一下 \(cnt\) 数量的转移,最优的 \(l\) 是多少段,然后在它基础上 \(+1\) 就能实现数量 check 和答案计算了。当然暴力转移方程不行,我们考虑斜率优化。

先打开那个平方式子化简下看看:

\[pre[r]^2+pre[l]^2-2\times pre[r]\times pre[l]+1+2\times(pre[r]-pre[l]) \]

顺便把 \(v\) 加上其中 \(pre[r]^2+2\times pre[r]+1-v\) 显然是一个当前常数不用考虑,这是因为 \(v\)\(k\),而这是一个下凸包求最小值问题,所以我们的 \(k\) 为负数,则 \(g(x)-kx\) 每一个段也即物品去掉 \(k\),即 \(val-v\) 才是当前的物品值,我们来观察剩余式子:

\[-2\times pre[l]\times pre[r]+pre[l]^2-2\times pre[l] \]

\[对于\ l\ 来说的\ k=-2\times pre[l],b=pre[l]^2-2\times pre[l] \]

还没完,还有个 \(dp[l]\) 显然在此时为定值,\(b=dp[l]+pre[l]^2-2\times pre[l]\),这个就是一个经典斜率优化了,随便你咋优化,可以凸包二分/三分、半平面交、李超线段树....我喜爱李超树,这里维护好 \(kx+b\) 的同时,再维护一个 \(cnt\) 表示这个最优线段表示的分段长,方便我们知道当前最优分段长方便 check 即可。

注意到 \(x\)\(pre[r]\) 最坏为 \(1e8\),懒得离散化,所以我们直接动态开点的李超树即可。然后注意要插入一个第 \(0\) 段,\(k\) 显然为 \(0\)\(b\) 的话显然设为 \(dp[0]=0\) 这样就能查出来 \((0,0)\) 作为使用了,我们可以用第 \(n+1\) 编号线段表示。顺便说下,二分时的斜率也意味着偏值,物品新增的权值,而本题要求找最小值,那显然就是一个下凸包,想想斜率越大的话,切点的 \(x\) 显然也越小,所以当 check 满足 \(\le m\) 时,随便看一边,比如左边,由于斜率 \(k\) 是负数。

如图所示,假如 \(x_2 \ge m\),那么我们应该二分到 \(x_1\) 判断是否满足 \(\le m\),而从蓝色斜率到红色斜率显然是 \(k\) 变小了,接近负无穷了,所以 \(check\) 成功,应该从 \(r\) 处移动收敛。顺便的 \(dp[1]=(pre[1]+1)^2+k\)。如果是右半部分就恰巧相反。如果 \(min\) 有多个相同的,但他们的 \(cnt\) 不同怎么整,假如这当中恰好有一种就是唯一的分段方式,那么我们如果 check 不成立,这种方式就无论如何都选不到了,所以 check 需要成功,那么我们尽量让 \(cnt\) 小即可。

最后的细节,二分的斜率开大一点,然后 \(check\) 满足的话让 \(k\) 变小就行了。当然离散化可以降低李超树常数,这里就写动态开点版本进行演示。

参照代码
#include <bits/stdc++.h>

// #pragma GCC optimize(2)
// #pragma GCC optimize("Ofast,no-stack-protector,unroll-loops,fast-math")
// #pragma GCC target("sse,sse2,sse3,ssse3,sse4.1,sse4.2,avx,avx2,popcnt,tune=native")

// #define isPbdsFile

#ifdef isPbdsFile

#include <bits/extc++.h>

#else

#include <ext/pb_ds/priority_queue.hpp>
#include <ext/pb_ds/hash_policy.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/trie_policy.hpp>
#include <ext/pb_ds/tag_and_trait.hpp>
#include <ext/pb_ds/hash_policy.hpp>
#include <ext/pb_ds/list_update_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/exception.hpp>
#include <ext/rope>

#endif

using namespace std;
using namespace __gnu_cxx;
using namespace __gnu_pbds;
typedef long long ll;
typedef long double ld;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
typedef tuple<int, int, int> tii;
typedef tuple<ll, ll, ll> tll;
typedef unsigned int ui;
typedef unsigned long long ull;
typedef __int128 i128;
#define hash1 unordered_map
#define hash2 gp_hash_table
#define hash3 cc_hash_table
#define stdHeap std::priority_queue
#define pbdsHeap __gnu_pbds::priority_queue
#define sortArr(a, n) sort(a+1,a+n+1)
#define all(v) v.begin(),v.end()
#define yes cout<<"YES"
#define no cout<<"NO"
#define Spider ios_base::sync_with_stdio(false);cin.tie(nullptr);cout.tie(nullptr);
#define MyFile freopen("..\\input.txt", "r", stdin),freopen("..\\output.txt", "w", stdout);
#define forn(i, a, b) for(int i = a; i <= b; i++)
#define forv(i, a, b) for(int i=a;i>=b;i--)
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
#define endl '\n'
//用于Miller-Rabin
[[maybe_unused]] static int Prime_Number[13] = {0, 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37};

template <typename T>
int disc(T* a, int n)
{
    return unique(a + 1, a + n + 1) - (a + 1);
}

template <typename T>
T lowBit(T x)
{
    return x & -x;
}

template <typename T>
T Rand(T l, T r)
{
    static mt19937 Rand(time(nullptr));
    uniform_int_distribution<T> dis(l, r);
    return dis(Rand);
}

template <typename T1, typename T2>
T1 modt(T1 a, T2 b)
{
    return (a % b + b) % b;
}

template <typename T1, typename T2, typename T3>
T1 qPow(T1 a, T2 b, T3 c)
{
    a %= c;
    T1 ans = 1;
    for (; b; b >>= 1, (a *= a) %= c)if (b & 1)(ans *= a) %= c;
    return modt(ans, c);
}

template <typename T>
void read(T& x)
{
    x = 0;
    T sign = 1;
    char ch = getchar();
    while (!isdigit(ch))
    {
        if (ch == '-')sign = -1;
        ch = getchar();
    }
    while (isdigit(ch))
    {
        x = (x << 3) + (x << 1) + (ch ^ 48);
        ch = getchar();
    }
    x *= sign;
}

template <typename T, typename... U>
void read(T& x, U&... y)
{
    read(x);
    read(y...);
}

template <typename T>
void write(T x)
{
    if (typeid(x) == typeid(char))return;
    if (x < 0)x = -x, putchar('-');
    if (x > 9)write(x / 10);
    putchar(x % 10 ^ 48);
}

template <typename C, typename T, typename... U>
void write(C c, T x, U... y)
{
    write(x), putchar(c);
    write(c, y...);
}


template <typename T11, typename T22, typename T33>
struct T3
{
    T11 one;
    T22 tow;
    T33 three;

    bool operator<(const T3 other) const
    {
        if (one == other.one)
        {
            if (tow == other.tow)return three < other.three;
            return tow < other.tow;
        }
        return one < other.one;
    }

    T3() { one = tow = three = 0; }

    T3(T11 one, T22 tow, T33 three) : one(one), tow(tow), three(three)
    {
    }
};

template <typename T1, typename T2>
void uMax(T1& x, T2 y)
{
    if (x < y)x = y;
}

template <typename T1, typename T2>
void uMin(T1& x, T2 y)
{
    if (x > y)x = y;
}

constexpr int N = 1e5 + 10;
constexpr ll INF = 1e10 + 7;
int mx;

struct Seg
{
    ll k, b;
    int cnt;

    Seg(const ll k, const ll b, const int cnt)
        : k(k),
          b(b),
          cnt(cnt)
    {
    }

    Seg() = default;

    ll getY(const ll x) const
    {
        return k * x + b;
    }
} seg[N];

struct Node
{
    int left, right;
    int id;
} node[N << 5];

#define left(x) node[x].left
#define right(x) node[x].right
#define id(x) node[x].id
int cnt;
int root;

inline void add(int& curr, int val, const int l = 1, const int r = mx)
{
    if (!curr)curr = ++cnt;
    if (!id(curr))
    {
        id(curr) = val;
        return;
    }
    const int mid = l + r >> 1;
    if (seg[id(curr)].getY(mid) > seg[val].getY(mid))swap(id(curr), val);
    if (l == r)return;
    if (seg[id(curr)].getY(l) > seg[val].getY(l))add(left(curr), val, l, mid);
    if (seg[id(curr)].getY(r) > seg[val].getY(r))add(right(curr), val, mid + 1, r);
}

inline void merge(pll& curr, const pll& other)
{
    if (curr.first > other.first)curr = other;
    if (curr.first == other.first)uMin(curr.second, other.second);
}

inline pll query(const int curr, const int val, const int l = 1, const int r = mx)
{
    if (!curr)return pll(INF, INF);
    pll ans(seg[id(curr)].getY(val), seg[id(curr)].cnt);
    const int mid = l + r >> 1;
    if (l == r)return ans;
    if (val <= mid)merge(ans, query(left(curr), val, l, mid));
    else merge(ans, query(right(curr), val, mid + 1, r));
    return ans;
}

inline void clear()
{
    forn(i, 1, cnt)
    {
        id(i) = left(i) = right(i) = 0;
    }
    cnt = root = 0;
}

ll pre[N], prePow2[N];
int n, m;
pll dp[N];

inline bool check(const ll k)
{
    clear();
    add(root, n + 1);
    forn(i, 1, n)
    {
        const auto [dpMin,dpCnt] = query(root, pre[i]);
        dp[i].first = prePow2[i] + 2 * pre[i] + 1 + k + dpMin;
        dp[i].second = dpCnt + 1;
        seg[i] = Seg(-2 * pre[i], dp[i].first + prePow2[i] - 2 * pre[i], dp[i].second);
        add(root, i);
    }
    return dp[n].second <= m;
}

inline void solve()
{
    seg[0].b = INF;
    cin >> n >> m;
    forn(i, 1, n)cin >> pre[i], uMax(mx, pre[i] += pre[i - 1]), prePow2[i] = pre[i] * pre[i];
    ll l = -1e18, r = 1e18;
    while (l < r)
    {
        const ll mid = l + r >> 1;
        if (check(mid))r = mid;
        else l = mid + 1;
    }
    check(l), cout << dp[n].first - l * m;
}

signed int main()
{
    // MyFile
    Spider
    //------------------------------------------------------
    // clock_t start = clock();
    int test = 1;
    //    read(test);
    // cin >> test;
    forn(i, 1, test)solve();
    //    while (cin >> n, n)solve();
    //    while (cin >> test)solve();
    // clock_t end = clock();
    // cerr << "time = " << double(end - start) / CLOCKS_PER_SEC << "s" << endl;
}

\[时间复杂度显然不及凸包的,双\ \log\ 级别:\ O(n\log{pre_{max}}\log{k}) \]

posted @ 2024-02-23 00:42  Athanasy  阅读(43)  评论(0编辑  收藏  举报