【CodeForces】671 C. Ultimate Weirdness of an Array
【题目】C. Ultimate Weirdness of an Array
【题意】给定长度为n的正整数序列,定义一个序列的价值为max(gcd(ai,aj)),1<=i<j<=n,定义f(i,j)为移除序列i~j后剩余序列的价值,求Σf(i,j)。1<=n,ai<=2*10^5。
【算法】数论+线段树
【题解】要求所有区间的f(i,j),转化为,记ans[i]表示f(l,r)=i的区间数量,则ANS=Σi*ans[i],i=1~mx,mx=max(ai)。
求解ans[i]不方便,记h[i]表示f(l,r)<=i的区间数量,则ans[i]=h[i]-h[i-1],显然h[i]单调不减。
考虑x=mx时,h[x]=n*(n+1)/2(即所有区间)。当x=mx-1时,设数列中mx的倍数有k个,就需要满足所选区间包含至少k-1个mx的倍数。随着x的减少,逐个将不满足的区间删除,就可以得到所有h[x]。
如何实现删除不满足的区间?
记v[i]表示数列中所有i的倍数的位置,用vector存储为v[i][k],这一步甚至可以用O(n√n)的分解素因数实现。
如果(l,r)合法,那么(l,R),R>r也一定合法。所以记next[i]表示L=i,R>=next[i]的区间均合法,初始next[i]=i。
对于x,h[x]=Σ(n-next[i]+1),i=1~n。
每次x减少,需要删除不满足x的区间时,假设x的倍数的位置为b1,b2……bk,需要进行以下3种操作:
1.p>b2,next[p]=n+1
2.b1<p<=b2,next[p]=max(next[p],bk)
3.1<=p<=b1,next[p]=max(next[p],bk-1)。
容易发现,next[i]是一个单调不减的数组,那么以上3个操作都可以用线段树的区间覆盖实现,答案用区间求和实现。(套路:线段树区间取max只能在单调的前提下通过区间覆盖实现)
具体而言,线段树参数需要维护max,min,delta,sum,当min>=x时直接return,当max<=x时直接修改。
复杂度O(n log n)。
#include<cstdio> #include<cstring> #include<cctype> #include<vector> #include<algorithm> #define ll long long using namespace std; int read(){ char c;int s=0,t=1; while(!isdigit(c=getchar()))if(c=='-')t=-1; do{s=s*10+c-'0';}while(isdigit(c=getchar())); return s*t; } const int maxn=200010; struct tree{int l,r,mins,maxs,delta;ll sum;}t[maxn*4]; int n,a[maxn]; ll h[maxn]; vector<int>v[maxn]; int min(int a,int b){return a<b?a:b;} int max(int a,int b){return a>b?a:b;} void up(int k){ t[k].sum=t[k<<1].sum+t[k<<1|1].sum; t[k].mins=min(t[k<<1].mins,t[k<<1|1].mins); t[k].maxs=max(t[k<<1].maxs,t[k<<1|1].maxs); } void modify(int k,int x){t[k].delta=x;t[k].mins=t[k].maxs=x;t[k].sum=1ll*x*(t[k].r-t[k].l+1);} void down(int k){ if(t[k].delta){ modify(k<<1,t[k].delta); modify(k<<1|1,t[k].delta); t[k].delta=0; } } void build(int k,int l,int r){ t[k].l=l;t[k].r=r;t[k].delta=0; if(l==r){ t[k].maxs=t[k].mins=t[k].sum=l; } else{ int mid=(l+r)>>1; build(k<<1,l,mid); build(k<<1|1,mid+1,r); up(k); } } void fix(int k,int l,int r,int x){ if(l>r||t[k].mins>=x)return; if(l<=t[k].l&&t[k].r<=r&&t[k].maxs<=x){modify(k,x);return;} down(k); int mid=(t[k].l+t[k].r)>>1; if(l<=mid)fix(k<<1,l,r,x); if(r>mid)fix(k<<1|1,l,r,x); up(k); } int main(){ n=read(); int mx=0; for(int i=1;i<=n;i++){ a[i]=read(); for(int j=1;j*j<=a[i];j++)if(a[i]%j==0){ v[j].push_back(i); if(j*j!=a[i])v[a[i]/j].push_back(i); } mx=max(mx,a[i]); } build(1,1,n); for(int i=mx+1;i>=1;i--){ if(v[i].size()>=2){ fix(1,v[i][1]+1,n,n+1); fix(1,v[i][0]+1,v[i][1],v[i][v[i].size()-1]); fix(1,1,v[i][0],v[i][v[i].size()-2]); } h[i]=1ll*n*(n+1)-t[1].sum; } ll ans=0; for(int i=1;i<=mx;i++)ans+=1ll*(i-1)*(h[i]-h[i-1]); printf("%lld",ans); return 0; }