【xsy2274】 平均值 线段树
题目大意:给你一个长度为$n$的序列$a$,请你求:
$\sum\limits_{l=1}^{n}\sum\limits_{r=l}^{n}\dfrac{mex(a_l,a_{l+1},...,a_r)}{r-l+1}$
对998244353取模
数据范围:$n≤5\times 10^5$
我们考虑把原先的式子转化一下
令$s[i]=\sum\limits_{j=1}^{i} \frac{1}{i}$。
令$f[i][l]$表示最小的$x$,满足$mex(a_l,a_{l+1},...,a_x)≥i$。若找不到这样的$x$,$f[i][l]=n+1$
不难发现,原先答案的式子我们可以转化:
$\sum\limits_{i=1}^{lim}\sum\limits_{l=1}^{n}s[n-l+1]-s[f[i][l]-l]$
其中lim表示最大的数x,满足0到x-1中的数都出现过
然后我们发现,当i不变时,$f[i]$的值是递增的,且有大量的值是相同的,且以区间的形式出现。
我们可以基于这个性质,通过线段树打标记,快速地将$f[i]$的值更新至$f[i+1]$。
线段树统计的同时维护答案的式子,每次累加即可。
1 #include<bits/stdc++.h> 2 #define M (1<<19) 3 #define L long long 4 #define mid ((a[x].l+a[x].r)>>1) 5 #define MOD 998244353 6 using namespace std; 7 8 L pow_mod(L x,L k){L ans=1;for(;k;k>>=1,x=x*x%MOD) if(k&1) ans=ans*x%MOD; return ans;} 9 L inv[M]={0},s[M]={0},n; 10 L S(int l,int r){return (s[r]-s[max(l-1,0)]+MOD)%MOD;} 11 12 struct seg{int l,r,tag,minn,maxn;L sum;}a[M*2]; 13 14 void pushup(int x){ 15 a[x].minn=min(a[x<<1].minn,a[x<<1|1].minn); 16 a[x].maxn=max(a[x<<1].maxn,a[x<<1|1].maxn); 17 a[x].sum=(a[x<<1].sum+a[x<<1|1].sum); 18 } 19 void upd(int x,int k){ 20 a[x].tag=a[x].minn=a[x].maxn=k; 21 a[x].sum=S(k-a[x].r,k-a[x].l); 22 } 23 void pushdown(int x){ 24 if(a[x].tag) upd(x<<1,a[x].tag),upd(x<<1|1,a[x].tag); 25 a[x].tag=0; 26 } 27 28 int build(int x,int l,int r){ 29 a[x].l=l; a[x].r=r; if(l==r) return a[x].minn=a[x].maxn=l; 30 build(x<<1,l,mid); build(x<<1|1,mid+1,r); pushup(x); 31 } 32 33 void updata(int x,int l,int r,int k){ 34 if(a[x].minn>=k) return; 35 if(a[x].l==a[x].r){ 36 a[x].minn=a[x].maxn=max(a[x].maxn,k); 37 a[x].sum=S(k-a[x].l,k-a[x].l); 38 return; 39 } 40 if(l<=a[x].l&&a[x].r<=r){ 41 if(a[x].maxn<k){ 42 upd(x,k); 43 return; 44 } 45 } 46 pushdown(x); 47 if(l<=mid) updata(x<<1,l,r,k); 48 if(mid<r) updata(x<<1|1,l,r,k); 49 pushup(x); 50 } 51 52 struct node{ 53 int x,id; node(){x=id=0;} 54 friend bool operator <(node a,node b){return a.x==b.x?a.id<b.id:a.x<b.x;} 55 }p[M]; 56 57 int main(){ 58 for(int i=1;i<M;i++) inv[i]=pow_mod(i,MOD-2); 59 for(int i=1;i<M;i++) s[i]=(s[i-1]+inv[i])%MOD; 60 for(int i=1;i<M;i++) s[i]=(s[i-1]+s[i])%MOD; 61 62 scanf("%d",&n); build(1,1,n); 63 for(int i=1;i<=n;i++) scanf("%d",&p[i].x),p[i].id=i; 64 sort(p+1,p+n+1); p[0].x=-1; p[n+1].x=19890604; 65 66 L ans=0,hh=0; 67 68 for(int i=1,j=1;i<=n;hh++){ 69 if(p[i].x!=p[i-1].x+1) break; 70 while(p[i].x==p[j].x) j++; 71 72 for(int last=0;i<=j;i++){ 73 if(i==j){ 74 if(last<n) updata(1,last+1,n,n+1); 75 break; 76 } 77 updata(1,last+1,p[i].id,p[i].id); 78 last=p[i].id; 79 } 80 81 ans=(ans+a[1].sum)%MOD; 82 } 83 84 cout<<(s[n]*hh-ans+MOD)%MOD<<endl; 85 }