CF1030F Putting Boxes Together

昨晚的比赛题。(像我这种蒟蒻只能打打div2)

题意

  给你$n$个物品,每一个物品$i$,有一个权值$w_i$和一个位置$a_i$,定义移动一个物品$i$到位置$t$的代价为$w_i * \left |a_i - t  \right |$,要求你写一个数据结构支持以下两种操作:

    1、修改一个物品的权值

    2、查询把一个区间内全部移到相邻的位置的最小值。

  举个栗子:如果要把$[l, r]$移到相邻的位置,就是对于$\forall i \in [l, r]$,要有$pos_i = x + i - l\ (1 \leq x \leq n - (r - l))$,然后要确定这个$x$使移动的总代价最小,最后要求这个最小的代价对$1e9 + 7$取模的结果,每次询问独立。

  注意:要先使总代价最小然后再取模,而不是取模后最小。

  保证给出的$a_i$递增。

两个原题:

  一个简单题:

    我们有很经典的货仓选址的模型,就是在直线上有$n$个点,每一个点$i$有一个位置$pos_i$,每一个点有一个货物。定义运输货物的代价是移动的距离。现在要在直线上选择一个点建立货仓,要把所有的货物都运到这个点,要求使使代价最小,求这个最小代价。

    很简单吧,中位数。

    抄一段lyd书上的证明:先把所有的点按照$pos_i$排序,假设货仓建在$X$,左侧的点有$P$个,右侧的点有$Q$个。如果$P < Q$,那么把$X$往右移动会使答案变优,同理当$P > Q$使把$X$向左移动会使答案变优,所有最优解会在$P == Q$的地方产生。

    再抄一句:当$n$是偶数的时候,这时$pos_{\frac{n}{2}}$和$pos_{\frac{n + 1}{2}}$中的点都可以是最优解。

  稍微强化板:

    现在每一个货仓$i$里有$w_i$个货物。

    排序后,找到第一个$X$,使$\sum_{i = 1}^{X}a_i \geq \sum_{i = X + 1}^{n}a_i$,$X$就是最优解。

    这个东西叫做带权中位数

    丢一个百度百科的链接,里面有证明。                   传送门

    我自己把不严谨的证明在这里再写一遍:

      假设最优答案在$T$取到,那么有(唔,这里$a_i$代表权值):

          $\sum_{i = 1}^{n}a_i  *dis(i, T) \leq \sum_{i = 1}^{n}a_i * dis(i, T + 1)$

      变形一下:

          $\sum_{i = 1}^{T - 1}a_i  *dis(i, T) + \sum_{i = T + 1}^{n}a_i  *dis(i, T)  + a_{T + 1} * dis(T, T + 1)\leq \sum_{i = 1}^{T}a_i  *dis(i, T + 1) + \sum_{i = T + 2}^{n}a_i  *dis(i, T + 1)  + a_T * dis(T, T + 1)$

      

      发现$T$左边的点走到$T + 1$与走到$T$比,多走了$dis(T, T + 1)$,而右边的点则少走了$dis(T, T + 1)$。

      消掉一模一样的东西就得到了: $\sum_{i = 1}^{T}a_i \geq \sum_{i = T + 1}^{n}a_i$。

      把$T$和$T - 1$代进去也是一样的结果。

回到这题

  那么这题要求移到相邻的位置,可以理解为先移到同一个位置然后移回来,相对移动不变,我们只要找到这个带权中位数的位置,就能得到最优解了。

  带上修改,我们可以用两个树状数组来维护,一个维护$\sum_{i = 1}^{n}w_i$,另一个维护$\sum_{i = 1}^{n}w_i*(a_i - i)$,询问的时候先二分一下找到带权中位数的位置$pos$,然后对于$pos$左边的点向右移,对于$pos$右边的点向左移,就可以计算出答案了。

  时间复杂度$O(nlog^2{n})$。

Code:

#include <cstdio>
#include <cstring>
using namespace std;
typedef long long ll;

const int N = 2e5 + 5;
const ll P = 1e9 + 7;

int n, qn;
ll a[N], w[N];

template <typename T>
inline void read(T &X) {
    X = 0; char ch = 0; T op = 1;
    for(; ch > '9'|| ch < '0'; ch = getchar())
        if(ch == '-') op = -1;
    for(; ch >= '0' && ch <= '9'; ch = getchar())
        X = (X << 3) + (X << 1) + ch - 48;
    X *= op;
}

namespace BitSum {
    ll s[N];
    
    #define lowbit(p) (p & (-p))
    
    inline void modify(int p, ll v) {
        for(; p <= n; p += lowbit(p))
            s[p] += v;
    }
    
    inline ll query(int p) {
        ll res = 0LL;
        for(; p > 0; p -= lowbit(p))
            res += s[p];
        return res;
    }
    
    inline ll getSum(int l, int r) {
        if(r < l) return 0LL; 
        return query(r) - query(l - 1);
    }
    
} 

namespace BitMul {
    ll s[N];
    
    #define lowbit(p) (p & (-p))
    
    inline void modify(int p, ll v) {
        v %= P;
        for(; p <= n; p += lowbit(p))
            (s[p] = s[p] + v + P) %= P;
    }
    
    inline ll query(int p) {
        ll res = 0LL;
        for(; p > 0; p -= lowbit(p))
            (res += s[p]) %= P;
        return res;
    }
    
    inline ll getSum(int l, int r) {
        return (query(r) - query(l - 1) + P) % P;
    }
    
} 

inline int getPos(int x, int y) {
    int ln = x, rn = y, mid, res;
    for(; ln <= rn; ) {
        mid = (ln + rn) / 2;
        if(BitSum :: getSum(x, mid) >= BitSum :: getSum(mid + 1, y))
            res = mid, rn = mid - 1;
        else ln = mid + 1;
    }
    return res;
}

inline ll abs(ll x) {
    return x > 0 ? x : -x;
}

inline ll max(ll x, ll y) {
    return x > y ? x : y;
}

inline ll min(ll x, ll y) {
    return x > y ? y : x;
}

inline void solve(int x, int y) {
    if(x == y) {
        puts("0");
        return;
    }
    int pos = getPos(x, y);
/*    ll res = BitMul :: getSum(x, y); //d1 = 0LL, d2 = 0LL;
    d1 = (d1 - (BitSum :: getSum(pos, y) % P) * 1LL * abs(a[pos] - pos) % P + P) % P;
    d1 = (d1 + (BitSum :: getSum(x, pos - 1) % P) * 1LL * abs(a[pos] - pos) % P + P) % P;    
    d2 = (d2 - (BitSum :: getSum(pos + 1, y) % P) * 1LL * abs(a[pos] - pos) % P + P) % P;
    d2 = (d2 + (BitSum :: getSum(x, pos) % P) * 1LL * abs(a[pos] - pos) % P + P) % P;    
    ll d = 0LL;
    d = (d - (BitSum :: getSum(pos, y) % P) * 1LL * abs(a[pos] - pos) % P + P) % P;
    d = (d + (BitSum :: getSum(x, pos - 1) % P) * 1LL * abs(a[pos] - pos) % P + P) % P;    */
    
    ll res = 0LL;
    res = (-BitMul :: getSum(x, pos) + (BitSum :: getSum(x, pos) % P) * abs(a[pos] - pos) % P + P) % P;    
    res = (res - (BitSum :: getSum(pos, y)) % P * abs(a[pos] - pos) % P + BitMul :: getSum(pos, y) + P) % P;

//    printf("%lld\n", (res + d + P) % P);
    printf("%lld\n", res);
}

int main() {
    read(n), read(qn);
    for(int i = 1; i <= n; i++) read(a[i]);
    for(int i = 1; i <= n; i++) {
        read(w[i]);
        BitMul :: modify(i, w[i] * (a[i] - i));
        BitSum :: modify(i, w[i]);
    }
    
    for(int x, y; qn--; ) {
        read(x), read(y);
        if(x < 0) {
            x = -x;
            BitSum :: modify(x, -w[x]);
            BitMul :: modify(x, -1LL * w[x] * (a[x] - x));
            w[x] = 1LL * y;
            BitSum :: modify(x, w[x]);
            BitMul :: modify(x, 1LL * w[x] * (a[x] - x));
        } else solve(x, y);
    }
    
/*    for(int i = 1; i <= n; i++)
        printf("%lld ", BitSum :: getSum(i, i));    */
    
    return 0;
}
View Code
posted @ 2018-09-24 10:42  CzxingcHen  阅读(728)  评论(15编辑  收藏  举报