CF1030F Putting Boxes Together
昨晚的比赛题。(像我这种蒟蒻只能打打div2)
题意
给你$n$个物品,每一个物品$i$,有一个权值$w_i$和一个位置$a_i$,定义移动一个物品$i$到位置$t$的代价为$w_i * \left |a_i - t \right |$,要求你写一个数据结构支持以下两种操作:
1、修改一个物品的权值
2、查询把一个区间内全部移到相邻的位置的最小值。
举个栗子:如果要把$[l, r]$移到相邻的位置,就是对于$\forall i \in [l, r]$,要有$pos_i = x + i - l\ (1 \leq x \leq n - (r - l))$,然后要确定这个$x$使移动的总代价最小,最后要求这个最小的代价对$1e9 + 7$取模的结果,每次询问独立。
注意:要先使总代价最小然后再取模,而不是取模后最小。
保证给出的$a_i$递增。
两个原题:
一个简单题:
我们有很经典的货仓选址的模型,就是在直线上有$n$个点,每一个点$i$有一个位置$pos_i$,每一个点有一个货物。定义运输货物的代价是移动的距离。现在要在直线上选择一个点建立货仓,要把所有的货物都运到这个点,要求使使代价最小,求这个最小代价。
很简单吧,中位数。
抄一段lyd书上的证明:先把所有的点按照$pos_i$排序,假设货仓建在$X$,左侧的点有$P$个,右侧的点有$Q$个。如果$P < Q$,那么把$X$往右移动会使答案变优,同理当$P > Q$使把$X$向左移动会使答案变优,所有最优解会在$P == Q$的地方产生。
再抄一句:当$n$是偶数的时候,这时$pos_{\frac{n}{2}}$和$pos_{\frac{n + 1}{2}}$中的点都可以是最优解。
稍微强化板:
现在每一个货仓$i$里有$w_i$个货物。
排序后,找到第一个$X$,使$\sum_{i = 1}^{X}a_i \geq \sum_{i = X + 1}^{n}a_i$,$X$就是最优解。
这个东西叫做带权中位数。
丢一个百度百科的链接,里面有证明。 传送门
我自己把不严谨的证明在这里再写一遍:
假设最优答案在$T$取到,那么有(唔,这里$a_i$代表权值):
$\sum_{i = 1}^{n}a_i *dis(i, T) \leq \sum_{i = 1}^{n}a_i * dis(i, T + 1)$
变形一下:
$\sum_{i = 1}^{T - 1}a_i *dis(i, T) + \sum_{i = T + 1}^{n}a_i *dis(i, T) + a_{T + 1} * dis(T, T + 1)\leq \sum_{i = 1}^{T}a_i *dis(i, T + 1) + \sum_{i = T + 2}^{n}a_i *dis(i, T + 1) + a_T * dis(T, T + 1)$
发现$T$左边的点走到$T + 1$与走到$T$比,多走了$dis(T, T + 1)$,而右边的点则少走了$dis(T, T + 1)$。
消掉一模一样的东西就得到了: $\sum_{i = 1}^{T}a_i \geq \sum_{i = T + 1}^{n}a_i$。
把$T$和$T - 1$代进去也是一样的结果。
回到这题
那么这题要求移到相邻的位置,可以理解为先移到同一个位置然后移回来,相对移动不变,我们只要找到这个带权中位数的位置,就能得到最优解了。
带上修改,我们可以用两个树状数组来维护,一个维护$\sum_{i = 1}^{n}w_i$,另一个维护$\sum_{i = 1}^{n}w_i*(a_i - i)$,询问的时候先二分一下找到带权中位数的位置$pos$,然后对于$pos$左边的点向右移,对于$pos$右边的点向左移,就可以计算出答案了。
时间复杂度$O(nlog^2{n})$。
Code:
#include <cstdio> #include <cstring> using namespace std; typedef long long ll; const int N = 2e5 + 5; const ll P = 1e9 + 7; int n, qn; ll a[N], w[N]; template <typename T> inline void read(T &X) { X = 0; char ch = 0; T op = 1; for(; ch > '9'|| ch < '0'; ch = getchar()) if(ch == '-') op = -1; for(; ch >= '0' && ch <= '9'; ch = getchar()) X = (X << 3) + (X << 1) + ch - 48; X *= op; } namespace BitSum { ll s[N]; #define lowbit(p) (p & (-p)) inline void modify(int p, ll v) { for(; p <= n; p += lowbit(p)) s[p] += v; } inline ll query(int p) { ll res = 0LL; for(; p > 0; p -= lowbit(p)) res += s[p]; return res; } inline ll getSum(int l, int r) { if(r < l) return 0LL; return query(r) - query(l - 1); } } namespace BitMul { ll s[N]; #define lowbit(p) (p & (-p)) inline void modify(int p, ll v) { v %= P; for(; p <= n; p += lowbit(p)) (s[p] = s[p] + v + P) %= P; } inline ll query(int p) { ll res = 0LL; for(; p > 0; p -= lowbit(p)) (res += s[p]) %= P; return res; } inline ll getSum(int l, int r) { return (query(r) - query(l - 1) + P) % P; } } inline int getPos(int x, int y) { int ln = x, rn = y, mid, res; for(; ln <= rn; ) { mid = (ln + rn) / 2; if(BitSum :: getSum(x, mid) >= BitSum :: getSum(mid + 1, y)) res = mid, rn = mid - 1; else ln = mid + 1; } return res; } inline ll abs(ll x) { return x > 0 ? x : -x; } inline ll max(ll x, ll y) { return x > y ? x : y; } inline ll min(ll x, ll y) { return x > y ? y : x; } inline void solve(int x, int y) { if(x == y) { puts("0"); return; } int pos = getPos(x, y); /* ll res = BitMul :: getSum(x, y); //d1 = 0LL, d2 = 0LL; d1 = (d1 - (BitSum :: getSum(pos, y) % P) * 1LL * abs(a[pos] - pos) % P + P) % P; d1 = (d1 + (BitSum :: getSum(x, pos - 1) % P) * 1LL * abs(a[pos] - pos) % P + P) % P; d2 = (d2 - (BitSum :: getSum(pos + 1, y) % P) * 1LL * abs(a[pos] - pos) % P + P) % P; d2 = (d2 + (BitSum :: getSum(x, pos) % P) * 1LL * abs(a[pos] - pos) % P + P) % P; ll d = 0LL; d = (d - (BitSum :: getSum(pos, y) % P) * 1LL * abs(a[pos] - pos) % P + P) % P; d = (d + (BitSum :: getSum(x, pos - 1) % P) * 1LL * abs(a[pos] - pos) % P + P) % P; */ ll res = 0LL; res = (-BitMul :: getSum(x, pos) + (BitSum :: getSum(x, pos) % P) * abs(a[pos] - pos) % P + P) % P; res = (res - (BitSum :: getSum(pos, y)) % P * abs(a[pos] - pos) % P + BitMul :: getSum(pos, y) + P) % P; // printf("%lld\n", (res + d + P) % P); printf("%lld\n", res); } int main() { read(n), read(qn); for(int i = 1; i <= n; i++) read(a[i]); for(int i = 1; i <= n; i++) { read(w[i]); BitMul :: modify(i, w[i] * (a[i] - i)); BitSum :: modify(i, w[i]); } for(int x, y; qn--; ) { read(x), read(y); if(x < 0) { x = -x; BitSum :: modify(x, -w[x]); BitMul :: modify(x, -1LL * w[x] * (a[x] - x)); w[x] = 1LL * y; BitSum :: modify(x, w[x]); BitMul :: modify(x, 1LL * w[x] * (a[x] - x)); } else solve(x, y); } /* for(int i = 1; i <= n; i++) printf("%lld ", BitSum :: getSum(i, i)); */ return 0; }