Codeforces 1701F. Points (2500)
给定两个正整数 \(q,d\),定义三元组 \((i,j,k)\) 满足 \(i<j<k\bigcap k-i\le d\) 为美丽三元组,现在有一个空集和 \(q\) 组询问,每次给定一个正整数 \(x\),若 \(x\) 不在集合,那么将 \(x\) 加入集合,若 \(x\) 在集合中,那么将 \(x\) 从集合中删除,每次询问计算集合中美丽三元组的个数。
\(1\le q,d,x\le 2\cdot 10^5\)。
考虑对每个三元组以 \(i\) 计数,令 \([i+1,i+d]\) 中出现的数的个数为 \(k_i\),则对答案的贡献为 \(\binom{k_i}{2}\)。
那么考虑添加\(/\)删除对答案以及 \(k_i\) 的影响。
添加 \(x\) 后,对答案的贡献一种是作为 \(i\),还有一种是作为 \(j,k(\)可以一起统计\()\)。作为 \(i\) 的贡献显然是 \(\binom{k_x}{2}\),作为 \(j,k\) 的贡献是 \(\sum\limits_{i=x-d}^{x-1} k_i\),显然可以用线段树维护 \(k_i\) 的区间和。
对 \(k_i\) 的影响则是 \(\forall i\in [x-d,x-1],k_i\leftarrow k_i-1\),也就是区间减操作。
同理,删除也可用线段树维护。
总时间复杂度 \(O(q\log V)\),\(V\) 是值域大小。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int V=2e5+5;
int n,d,x,add[V<<2];bool vis[V];
ll ans,sum[V<<2],psum[V<<2],insum[V<<2];
inline void pushup(int x){
psum[x]=psum[x<<1]+psum[x<<1|1];
sum[x]=sum[x<<1]+sum[x<<1|1];
insum[x]=insum[x<<1]+insum[x<<1|1];
}
inline void pushdown(int x,int l,int r){
if(!add[x])return;int mid=l+r>>1,v=add[x];
add[x<<1]+=v,add[x<<1|1]+=v;
sum[x<<1]+=(ll)v*psum[x<<1],sum[x<<1|1]+=(ll)v*psum[x<<1|1];
insum[x<<1]+=(ll)v*(mid-l+1),insum[x<<1|1]+=(ll)v*(r-mid);
add[x]=0;
}
inline void update(int x,int l,int r,int L,int R,int v){
if(L>R)return;
if(L<=l&&R>=r){add[x]+=v,sum[x]+=(ll)v*psum[x],insum[x]+=(ll)v*(r-l+1);return;}
int mid=l+r>>1;pushdown(x,l,r);
if(L<=mid)update(x<<1,l,mid,L,R,v);
if(R>mid)update(x<<1|1,mid+1,r,L,R,v);
pushup(x);
}
inline void modify(int x,int l,int r,int y,int v){
if(l==r){psum[x]+=v,sum[x]=v>0?insum[x]:0;return;}
int mid=l+r>>1;pushdown(x,l,r);
y<=mid?modify(x<<1,l,mid,y,v):modify(x<<1|1,mid+1,r,y,v);
pushup(x);
}
inline int queryN(int x,int l,int r,int L,int R){
if(L>R)return 0;
if(L<=l&&R>=r)return psum[x];
int mid=l+r>>1,ans=0;pushdown(x,l,r);
if(L<=mid)ans=queryN(x<<1,l,mid,L,R);
if(R>mid)ans+=queryN(x<<1|1,mid+1,r,L,R);
return ans;
}
inline ll queryV(int x,int l,int r,int L,int R){
if(L>R)return 0;
if(L<=l&&R>=r)return sum[x];
int mid=l+r>>1;ll ans=0;pushdown(x,l,r);
if(L<=mid)ans=queryV(x<<1,l,mid,L,R);
if(R>mid)ans+=queryV(x<<1|1,mid+1,r,L,R);
return ans;
}
int main(){
scanf("%d%d",&n,&d);
for(int i=1;i<=n;++i){
scanf("%d",&x);
if(!vis[x]){
ans+=queryV(1,1,V,max(1,x-d),x-1);
update(1,1,V,max(1,x-d),x-1,1);
int num=queryN(1,1,V,x+1,min(x+d,V));
ans+=(ll)num*(num-1)/2;
modify(1,1,V,x,1);
}else{
update(1,1,V,max(1,x-d),x-1,-1);
ans-=queryV(1,1,V,max(1,x-d),x-1);
int num=queryN(1,1,V,x+1,min(x+d,V));
ans-=(ll)num*(num-1)/2;
modify(1,1,V,x,-1);
}
vis[x]^=1,printf("%lld\n",ans);
}
return 0;
}