LG7882 [Ynoi2006] rsrams【阈值法,分块,莫队】
给定长为 \(n\) 的序列 \(a_1,\cdots,a_n\),\(m\) 次询问区间 \([L,R]\),求其所有子区间的绝对众数之和。
\(n,m\le 10^6\),\(1\le a_i\le n\),时限 \(8.0\text{s}\)。
若固定绝对众数是 \(x\),要求多少个子区间的 \(2[a_i=x]-1\) 之和 \(>0\),取前缀和之后问题就是区间顺序对计数。
优化当然考虑对出现次数 \(c\) 根号分治。对于总和 \(>0\) 的区间 \([l,r]\),在前缀和 \(=0\) 的位置分段,每段都满足第一个元素是 \(1\) 且所有前缀和 \(\ge 0\),或者左右翻转的情况。所以我们枚举每个 \(1\) 的位置 \(p\),向前后分别扩展出后缀和/前缀和均 \(\ge 0\) 的极长区间 \((l,p]\) 和 \([p,r)\),将区间 \([l,r]\) 合并到同一个连通块,此时顺序对只会在一个连通区间内出现,且这样的区间的总长是 \(\mathcal O(c)\) 的。
对于长度 \(>\sqrt m\) 的区间,只有 \(\mathcal O(n/\sqrt m)\) 个,分别计算贡献到询问,就是在这个区间内做 \(m\) 次查询的莫队,注意不能直接对询问排序,你需要事先将所有询问按右端点排序,每次做莫队时对左端点所在块做桶排序,时间复杂度 \(\mathcal O(n\sqrt m)\)。
对于长度 \(\le\sqrt m\) 的区间,总的顺序对个数只有 \(\mathcal O(n\sqrt m)\),查询次数 \(\mathcal O(\sqrt m)\),使用 \(\mathcal O(1)\) 修改 \(\mathcal O(\sqrt n)\) 查询的分块维护二维数点,时间复杂度 \(\mathcal O(n\sqrt m+m\sqrt n)\)。
#include<bits/stdc++.h>
#define fi first
#define se second
using namespace std;
typedef long long LL;
typedef pair<int, int> pii;
const int N = 1000003;
int n, m, nb, mb, blen, bl[N], a[N], sum[N * 2], *buc = sum + N;
vector<int> v[N], ins[N], del[N];
LL ans[N];
struct Query {
int l, r, id;
Query(int _1 = 0, int _2 = 0, int _3 = 0): l(_1), r(_2), id(_3){}
bool operator < (const Query &o) const {return r < o.r;}
} q[N];
vector<pii> tq[N];
struct BIT {
LL val[N], sum[1003];
void upd(int p, int v){val[p] += v; sum[p / nb] += v;}
LL qry(int p){
LL res = 0;
for(int i = p / nb - 1;i >= 0;-- i) res += sum[i];
for(int i = p / nb * nb;i <= p;++ i) res += val[i];
return res;
}
} tr;
int main(){
ios::sync_with_stdio(0);
cin >> n >> m; nb = sqrt(n); mb = sqrt(m);
for(int i = 1;i <= n;++ i){
cin >> a[i]; v[a[i]].push_back(i);
}
for(int i = 1;i <= m;++ i){
cin >> q[i].l >> q[i].r; q[i].id = i;
tq[q[i].l - 1].emplace_back(q[i].r, i);
tq[q[i].r].emplace_back(q[i].r, i);
}
sort(q + 1, q + m + 1);
vector<Query> lar, sma;
for(int i = 1;i <= n;++ i) if(!v[i].empty()){
vector<pii> t0, t1, t2;
int nl = n + 1, nr = n + 1;
for(int j = (int)v[i].size() - 1;j >= 0;-- j)
if(v[i][j] < nl){
if(nr <= n) t0.emplace_back(nl, nr);
nr = v[i][j]; nl = max(nr - 1, 1);
} else nl = max(nl - 2, 1);
t0.emplace_back(nl, nr);
reverse(t0.begin(), t0.end());
nl = nr = 0;
for(int j = 0;j < v[i].size();++ j)
if(v[i][j] > nr){
if(nl) t1.emplace_back(nl, nr);
nl = v[i][j]; nr = min(nl + 1, n);
} else nr = min(nr + 2, n);
t1.emplace_back(nl, nr);
int p0 = 0, p1 = 0; nl = nr = 0;
while(p0 < t0.size() || p1 < t1.size()){
if(p0 != t0.size() && (p1 == t1.size() || t0[p0].fi < t1[p1].fi)){
if(nr < t0[p0].fi){
if(nl) t2.emplace_back(nl, nr);
nl = t0[p0].fi; nr = t0[p0].se;
} else nr = max(nr, t0[p0].se);
++ p0;
} else {
if(nr < t1[p1].fi){
if(nl) t2.emplace_back(nl, nr);
nl = t1[p1].fi; nr = t1[p1].se;
} else nr = max(nr, t1[p1].se);
++ p1;
}
}
t2.emplace_back(nl, nr);
for(const auto &[L, R] : t2) (R - L >= mb ? lar : sma).emplace_back(L, R, i);
}
for(const auto &[ql, qr, val] : lar){
vector<Query> nq;
for(int i = 1;i <= m;++ i)
if(q[i].l <= qr && q[i].r >= ql)
nq.emplace_back(max(q[i].l, ql), min(q[i].r, qr), q[i].id);
if(nq.empty()) continue;
blen = max(1., (qr - ql + 1) / sqrt(nq.size()));
for(int i = ql;i <= qr;++ i) bl[i] = (i - ql) / blen + 1;
memset(sum, 0, (bl[qr] + 1) << 2);
for(int i = 0;i < nq.size();++ i) ++ sum[bl[nq[i].l]];
for(int i = 1;i <= bl[qr];++ i) sum[i] += sum[i - 1];
vector<Query> nxtq(nq.size());
for(int i = 0;i < nq.size();++ i) nxtq[sum[bl[nq[i].l] - 1] ++] = nq[i];
nq.swap(nxtq); buc[0] = 1;
int nl = ql, nr = ql - 1, sl = 0, sr = 0, ssl = 0, ssr = 0; LL res = 0;
for(const auto &[l, r, id] : nq){
while(nr < r){
if(a[++ nr] == val) ssr += buc[sr ++];
else ssr -= buc[-- sr];
res += ssr; ++ buc[sr]; if(sl < sr) ++ ssl;
}
while(nl > l){
if(a[-- nl] == val) ssl += buc[sl --];
else ssl -= buc[++ sl];
res += ssl; ++ buc[sl]; if(sl < sr) ++ ssr;
}
while(nr > r){
res -= ssr; -- buc[sr]; if(sl < sr) -- ssl;
if(a[nr --] == val) ssr -= buc[-- sr];
else ssr += buc[sr ++];
}
while(nl < l){
res -= ssl; -- buc[sl]; if(sl < sr) -- ssr;
if(a[nl ++] == val) ssl -= buc[++ sl];
else ssl += buc[sl --];
}
ans[id] += res * val;
}
int len = qr - ql + 1; memset(buc - len, 0, (len + 1) << 3);
}
for(int i = 0;i < sma.size();++ i){
ins[sma[i].l].push_back(i);
del[sma[i].r + 1].push_back(i);
}
set<int> st;
for(int i = 1;i <= n;++ i){
for(int j : ins[i]) st.insert(j);
for(int j : del[i]) st.erase(j);
for(int j : st)
for(int k = i, nv = 0;k <= sma[j].r;++ k)
if((nv += (a[k] == sma[j].id ? 1 : -1)) > 0)
tr.upd(k, sma[j].id);
for(auto [j, id] : tq[i]) ans[id] += (i == j ? 1 : -1) * tr.qry(j);
}
for(int i = 1;i <= m;++ i) cout << ans[i] << '\n';
}