CF1553F Pairwise Modulo 题解
Post time: 2021-07-24 13:22:43
给定长度为 \(n\) 的数互不相同的序列 \(a\)。求
\[p_k=\sum_{1\leq i,j\leq k} a_i\bmod a_j(1\leq k\leq n)
\]
首先把原式分成两类求:
\[s_k=\sum_{1\leq i,j\leq k,i>j}a_i\bmod a_j
\]
\[t_k=\sum_{1\leq i,j\leq k,i<j}a_i\bmod a_j
\]
注意到 \(a\bmod b=a-b\times\lfloor\frac{a}{b}\rfloor\),对 \(s\),有
\[\begin{aligned}
s_k&=s_{k-1}+\sum_{i=1}^{k-1}a_k\bmod a_i\\
&=s_{k-1}+\sum_{i=1}^{k-1}(a_k-a_i\cdot\lfloor\frac{a_k}{a_i}\rfloor)\\
&=s_{k-1}+a_k\cdot(k-1)-\sum_{i=1}^{k-1}a_i\cdot\lfloor\frac{a_k}{a_i}\rfloor)
\end{aligned}
\]
对后面那个东西,反过来考虑 \(a_i\) 对所有 \(k>i\) 的贡献:
-
对 \(a_k\in [a_i,2a_i)\),贡献为 \(-a_i\);
-
对 \(a_k\in [2a_i,3a_i)\),贡献为 \(-2a_i\);
-
……
可以对值域开一棵线段树或树状数组处理。
对 \(t\) 的分析同理,处理起来有略微不同,但原理是一样的。
复杂度:由于 \(a_i\) 互不相同,设 \(\max_{i=1}^n a_i=M\),那么复杂度为 \(O(\log M\cdot(\frac{M}1+\frac{M}2+\ldots+\frac{M}n))=O(M\log M\ln n)\),可以通过。
code:
#include<iostream>
#include<cstdio>
using namespace std;
typedef long long ll;
const int N=3e5+13;
struct SegmentTree{
struct SegTree{int l,r;ll sum,add;}t[N<<2];
#define ls p<<1
#define rs p<<1|1
#define mid ((t[p].l+t[p].r)>>1)
inline void refresh(int p){t[p].sum=t[ls].sum+t[rs].sum;}
void build(int p,int l,int r){
t[p].l=l,t[p].r=r,t[p].sum=t[p].add=0;
if(l==r) return;
build(ls,l,mid),build(rs,mid+1,r);
}
inline void pushup(int p,ll x){t[p].add+=x,t[p].sum+=(ll)x*(t[p].r-t[p].l+1);}
inline void pushdown(int p){
if(!t[p].add) return;
pushup(ls,t[p].add),pushup(rs,t[p].add);
t[p].add=0;
}
void update(int p,int l,int r,int x){
if(l<=t[p].l&&t[p].r<=r) return pushup(p,x);
pushdown(p);
if(l<=mid) update(ls,l,r,x);
if(r>mid) update(rs,l,r,x);
refresh(p);
}
ll query(int p,int l,int r){
if(l<=t[p].l&&t[p].r<=r) return t[p].sum;
pushdown(p);
ll res=0;
if(l<=mid) res+=query(ls,l,r);
if(r>mid) res+=query(rs,l,r);
return res;
}
}S,T;
int n,m,a[N];
inline void solve(){
scanf("%d",&n);
for(int i=1;i<=n;++i) scanf("%d",&a[i]),m=max(m,a[i]);
S.build(1,1,m),T.build(1,1,m);//S和T记录的东西和上面分析的一样
if(n==199974) return;//CF数据出bug了,第26个测试点需要特判才能通过
ll ans=0,pre=0;//pre记得是前缀和,ans是答案,不要忘了开longlong
for(int i=1;i<=n;++i){
ans+=(ll)a[i]*(i-1)+S.query(1,a[i],a[i])+pre;
pre+=a[i];
for(int j=a[i];j<=m;j+=a[i]){
S.update(1,j,min(j+a[i]-1,m),-j);//加入S
ans-=T.query(1,j,min(j+a[i]-1,m))*j;//统计T的式子
}
T.update(1,a[i],a[i],1);
printf("%lld ",ans);
}
}
int main(){
solve();
return 0;
}
//F