线段树+思维——cf1311F
/* 1-100t 3+2t 2+3t v<0的分一组,v>=0的分一组 */ #include<bits/stdc++.h> using namespace std; #define N 400005 #define ll long long struct Node{ ll x,v; }a[N],b[N]; int n,tot1,tot2; int cmp1(Node a,Node b){return a.x<b.x;} ll ans,Sum[N],x[N],m; #define lson l,m,rt<<1 #define rson m+1,r,rt<<1|1 ll sum[N<<2],cnt[N<<2]; void build(int l,int r,int rt){ memset(sum,0,sizeof sum); memset(cnt,0,sizeof cnt); } void update(int pos,ll v,int l,int r,int rt){ if(l==r){ cnt[rt]++;sum[rt]+=v;return; } int m=l+r>>1; if(pos<=m)update(pos,v,lson); else update(pos,v,rson); cnt[rt]=cnt[rt<<1]+cnt[rt<<1|1]; sum[rt]=sum[rt<<1]+sum[rt<<1|1]; } ll query1(int L,int R,int l,int r,int rt){ if(L<=l && R>=r)return cnt[rt]; int m=l+r>>1; ll res=0; if(L<=m)res+=query1(L,R,lson); if(R>m)res+=query1(L,R,rson); return res; } ll query2(int L,int R,int l,int r,int rt){ if(L<=l && R>=r)return sum[rt]; int m=l+r>>1; ll res=0; if(L<=m)res+=query2(L,R,lson); if(R>m)res+=query2(L,R,rson); return res; } ll t[N],tt[N]; int main(){ cin>>n; for(int i=1;i<=n;i++)cin>>t[i]; for(int i=1;i<=n;i++)cin>>tt[i]; for(int i=1;i<=n;i++){ int x,v;x=t[i];v=tt[i]; if(v<0){ tot1++; a[tot1].x=x;a[tot1].v=v; }else { tot2++; b[tot2].x=x;b[tot2].v=v; } } sort(a+1,a+1+tot1,cmp1); sort(b+1,b+1+tot2,cmp1); if(tot1 && tot2){//两组间的贡献 for(int i=1;i<=tot2;i++)Sum[i]=Sum[i-1]+b[i].x; for(int i=1;i<=tot1;i++){ if(a[i].x<b[1].x){ ans+=Sum[tot2]-a[i].x*tot2; continue; } int L=1,R=tot2,mid,pos; while(L<=R){ mid=L+R>>1; if(b[mid].x<=a[i].x) pos=mid,L=mid+1; else R=mid-1; } ans+=(Sum[tot2]-Sum[pos])-a[i].x*(tot2-pos); } } //a同组内的贡献:在a[i]左侧的且速度绝对值大的 for(int i=1;i<=tot1;i++)x[++m]=a[i].v; sort(x+1,x+1+m); m=unique(x+1,x+1+m)-x-1; build(1,m,1); for(int i=1;i<=tot1;i++){ int pos=lower_bound(x+1,x+1+m,a[i].v)-x; ll num=query1(1,pos,1,m,1); ll sum=query2(1,pos,1,m,1); ans+=num*a[i].x-sum; update(pos,a[i].x,1,m,1); } //b同组内的贡献:在b[i]左侧的且速度绝对值小的 m=0; for(int i=1;i<=tot2;i++)x[++m]=b[i].v; sort(x+1,x+1+m); m=unique(x+1,x+1+m)-x-1; build(1,m,1); for(int i=1;i<=tot2;i++){ int pos=lower_bound(x+1,x+1+m,b[i].v)-x; ll num=query1(1,pos,1,m,1); ll sum=query2(1,pos,1,m,1); ans+=num*b[i].x-sum; update(pos,b[i].x,1,m,1); } cout<<ans<<'\n'; }