AT_joisc2019_j 题解
先考虑这个式子:
\[\sum_{j=1}^{M} |C_{k_{j}} - C_{k_{j+1}}|
\]
一定是在 \(C\) 有序时取到,具体证明很简单各位读者自己证明。
那么现在式子变成:
\[\sum{V} + 2 \times({C_{\max} - C_{\min}})
\]
这个时候一个常见的技巧是将 \(C\) 排序。
这个时候就可以定义状态:
\[dp_{i,j} = \sum{V} + 2 \times (C_{j} - C_{i})
\]
然后从贪心的思想出发,\(V\) 一定是选取区间 \([i,j]\) 中最大的 \(M\) 个。
令 \(f(i,j)\) 表示区间 \([i,j]\) 中前 \(M\) 大之和,有:
\[dp_{i,j} = f(i,j) + 2 \times (C_{j} - C_{i})
\]
考虑去掉一维状态:
\[dp_{i} = \max{(f(i,j) + 2 \times C_{j})} - 2 \times C_i
\]
因为 \(f(i,j) + 2 \times C_{j}\) 满足四边形不等式,所以 \(dp_{i}\) 满足决策单调性,考虑分治优化,\(f(i,j)\) 可以直接用主席树求解。
那么我们就 \(O(n \log^2 n)\) 地做完了。
#include<bits/stdc++.h>
#define int long long
using namespace std;
//dp[l][r] 表示区间 [l,r] 内前 M 大
//求 max(dp[l][r])
//max_{r>l}(dp[l][r])
//设计状态 dp[i] 表示区间 [i,j] 前 M 大之和减去 2*(c[j]-c[i])
//dp[i]=max{f(i,j)-2*c[j]}+2*c[i]
const int maxn = 2e5+114;
const int top = 1e9+114;
const int inf = 2e18+114514;
int dp[maxn];
int n,m;
struct Node{
int sum,ls,rs;
int val;
}tr[maxn*35];
struct hhx{
int c,v;
}a[maxn];
int root[maxn],tot;
inline int kth(int lt,int rt,int L,int R,int k){
if(lt==rt){
return k*lt;
}
int mid=(lt+rt)>>1;
if((tr[tr[R].rs].sum-tr[tr[L].rs].sum)>=k){
return kth(mid+1,rt,tr[L].rs,tr[R].rs,k);
}
else{
return (tr[tr[R].rs].val-tr[tr[L].rs].val)+kth(lt,mid,tr[L].ls,tr[R].ls,k-(tr[tr[R].rs].sum-tr[tr[L].rs].sum));
}
}
inline void add(int cur,int lst,int lt,int rt,int pos){
tr[cur].sum=tr[lst].sum+1;
tr[cur].val=tr[lst].val+pos;
if(lt==rt){
return ;
}
int mid=(lt+rt)>>1;
if(pos<=mid){
tr[cur].rs=tr[lst].rs;
tr[cur].ls=++tot;
add(tr[cur].ls,tr[lst].ls,lt,mid,pos);
}
else{
tr[cur].ls=tr[lst].ls;
tr[cur].rs=++tot;
add(tr[cur].rs,tr[lst].rs,mid+1,rt,pos);
}
}
//前 k 大之和
void solve(int l,int r,int L,int R){
if(l>r) return ;
int mid=(l+r)>>1;
int mx=-inf,p=0;
for(int i=max(mid+m-1,L);i<=R;i++){
if(kth(1,top,root[mid-1],root[i],m)-2*a[i].c>mx){
mx=kth(1,top,root[mid-1],root[i],m)-2*a[i].c;
p=i;
}
}
dp[mid]=mx+2*a[mid].c;
solve(l,mid-1,L,p);
solve(mid+1,r,p,R);
}
bool cmp(hhx A,hhx B){
return A.c<B.c;
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cin>>n>>m;
for(int i=1;i<=n;i++) dp[i]=-inf;
for(int i=1;i<=n;i++) cin>>a[i].v>>a[i].c;
sort(a+1,a+n+1,cmp);
for(int i=1;i<=n;i++){
root[i]=++tot;
add(root[i],root[i-1],1,top,a[i].v);
}
solve(1,n-m+1,1,n);
int ans=-inf;
for(int i=1;i<=n;i++) ans=max(ans,dp[i]);
cout<<ans;
}