【题解】[CodePlus 2017 11 月赛]Yazid 的新生舞会
题目分析:
如果看到部分分的话大概可以想到枚举每一个数,然后去找它为绝对众数的区间的个数,因为一个区间的绝对众数最多只有一个,所以这样算并不会算重。
那么如果我们枚举了哪个数是绝对众数,我们就可以将这个数的权值设置为 \(1\),其他数的权值设置为 \(-1\),这样对于一个区间若区间和大于零则意味着这个数是这个区间的绝对众数。
我们记 \(sum_i\) 表示权值的前缀和,那么一个区间符合条件即 \(sum_r - sum_l > 0\)。
但是我们发现这样还是不行,总不能一个个去枚举吧。但是我们发现一点性质,对于这个数出现的两个位置的中间部分的 \(sum\) 一定是一个等差数列,因为这里面每一个数的权值都是 \(-1\),所以是等差数列,而因为我们要求小于 \(sum_i\) 的个数,所以对于 \(sum\) 的维护我们可以考虑维护一个桶,那么就可以直接考虑使用线段树维护区间加等差数列,这样对于 \(sum\) 我们就可以做到高效维护了,下文就设 \(cnt_i\) 表示 \(sum_j \le i\) 的 \(j\) 的数量。
假设我们现在的区间为 \([l,r]\),\(sum_l = x\) 且 \(sum_r = y\),因为上文的分析也就是说 \([l,r]\) 上 \(sum\) 是一个等差数列,我们这个区间的答案其实就是 \(\sum_{i=x-1}^{y-1} cnt_i\)。
剩下的就是对于 \(x,y\) 怎么求了,实际上我们可以认为每一个位置都有一个 \(-1\) 的贡献,然后对于每一次出现的位置有一个额外的 \(2\) 的贡献,就很好求了。
代码:
点击查看代码
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N = 5e5+5;
int val[N],a[N],tags[4*N],tagd[4*N],sum[4*N];
vector<int> v[N];
int get(int s,int d,int len){
return (s + s + d * (len - 1)) * len / 2;
}
void update(int now,int l,int r,int s,int d){
tags[now] += s;tagd[now] += d;
sum[now] += get(s,d,r-l+1);
}
void pushdown(int now,int l,int r){
if(tags[now] || tagd[now]){
int mid = (l + r)>>1;
update(now<<1,l,mid,tags[now],tagd[now]);
update(now<<1|1,mid+1,r,tags[now] + tagd[now] * (mid - l + 1),tagd[now]);
tags[now] = tagd[now] = 0;
}
}
void pushup(int now){
sum[now] = sum[now<<1] + sum[now<<1|1];
}
void modify(int now,int now_l,int now_r,int l,int r,int s,int d){
if(l <= now_l && r >= now_r){
update(now,now_l,now_r,s + (now_l - l) * d,d);
return;
}
pushdown(now,now_l,now_r);
int mid = (now_l + now_r)>>1;
if(l <= mid) modify(now<<1,now_l,mid,l,r,s,d);
if(r > mid) modify(now<<1|1,mid+1,now_r,l,r,s,d);
pushup(now);
}
int query(int now,int now_l,int now_r,int l,int r){
if(l <= now_l && r >= now_r) return sum[now];
pushdown(now,now_l,now_r);
int mid = (now_l + now_r)>>1,ans = 0;
if(l <= mid) ans += query(now<<1,now_l,mid,l,r);
if(r > mid) ans += query(now<<1|1,mid+1,now_r,l,r);
return ans;
}
signed main(){
// freopen("in.txt","r",stdin);
// freopen("out.txt","w",stdout);
int n,tp;scanf("%lld%lld",&n,&tp);
for(int i=1; i<=n; i++) scanf("%lld",&a[i]),v[a[i]].push_back(i);
for(int i=0; i<n; i++) v[i].push_back(n+1);
int ans = 0,p = n + 1;
for(int i=0; i<n; i++){
int lst = 0;
for(int j = 0; j<v[i].size(); j++){
int y = 2 * j - lst + p,x = 2 * j - (v[i][j] - 1) + p;
ans += query(1,1,2*n+1,max(x-1,1ll),y-1);
modify(1,1,2*n+1,x,y,1,1);
if(y + 1 <= 2 * n + 1) modify(1,1,2*n+1,y+1,2*n+1,(y-x+1),0);
lst = v[i][j];
}
lst = 0;
for(int j = 0; j<v[i].size(); j++){
int y = 2 * j - lst + p,x = 2 * j - (v[i][j] - 1) + p;
modify(1,1,2*n+1,x,y,-1,-1);
if(y + 1 <= 2 * n + 1) modify(1,1,2*n+1,y+1,2*n+1,-(y-x+1),0);
lst = v[i][j];
}
}
printf("%lld\n",ans);
return 0;
}