莫队学习笔记
引入问题
给出一个长度为 \(n\) 的数组 \(A\),接下来 \(q\) 次询问,每次询问求 \([l,r]\) 中有多少组 \((i,j,k)\) 使得 \(a_i=a_j=a_k(i<j<k)\)。
保证 \(1\leq n\leq 10^5,1\leq A_i\leq n(1\leq i\leq n)\)。
莫队的基础思想——区间转移
简单分析问题,貌似并没有可加性,所以分块和线段树肯定寄了。
但是本题中相邻区间转移是 \(O(1)\):
- 对于增加元素,设原有此元素 \(n\) 个,则增加三元组 \(\dfrac 1 2n(n-1)\) 个。
- 对于删除元素,设减去该元素后有此元素 \(n\) 个,则减少三元组 \(\dfrac 1 2n(n-1)\) 个。
莫队算法是一种解决此类问题的离线算法,它基于以下思想:
我们不难考虑到一个暴力代码,它的框架如下:
void del(LL x)
{
cnt[a[x]]--;
sum-=cnt[a[x]]*(cnt[a[x]]-1)/2;
}
void ins(LL x)
{
sum+=cnt[a[x]]*(cnt[a[x]]-1)/2;
cnt[a[x]]++;
}
LL l=1,r=0,sl,sr;
for(int i=1;i<=q;i++)
{
sl=b[i].l,sr=b[i].r;
while(l<sl)del(l++);
while(sl<l)ins(--l);
while(r<sr)ins(++r);
while(sr<r)del(r--);
ans[b[i].id]=sum;
}
该代码中出现的函数,ins(x)
表示算上 \(x\) 这一项的贡献,del(x)
表示删除 \(x\) 这一项的贡献。
不难看出,这份代码本质上是对于询问区间的移动。
如果我们保证区间移动次数较少的话,时间复杂度也会比较优秀。
我们有什么思路呢?
时间复杂度优秀的秘密——分块
我们取 \(B=\sqrt n\) 为块长进行分块,然后,按照左边界所在的块的编号给询问区间排序,同一个块则用右边界排序。
不难得到如下代码:
bool cmp(node x,node y)
{
if(x.l/B==y.l/B)return x.r<y.r;
return x.l/B<y.l/B;
}
我们来分析一下时间复杂度:
- 左边界块内移动,每次时间复杂度 \(O(\sqrt n)\),有 \(q\) 次,时间复杂度为 \(O(q\sqrt n)\)。
- 左边界越块移动,每次翻过一个块的时间复杂度是 \(O(\sqrt n)\),有 \(\sqrt n\) 次,时间复杂度为 \(O(n)\)。
- 对于每个左边界的块,右边界的移动整体来看都是 \(O(n)\) 的,有 \(\sqrt n\) 个块,时间复杂度为 \(O(n\sqrt n)\)。
因此,莫队的整体时间复杂度是 \(O(n\sqrt n)\)。
玄学优化——奇偶性优化
对于编号为奇数的块,我们用右边界从小到大排序,对于编号为偶数的块,我们用右边界从大到小排序。
这样会快一点,因为不加优化之前来到新的块右端点需要先回溯至这个块里最小的右端点。
但是加了这个优化,有时我们可以从原先的最右边从右往左依次处理新的块的询问,所以会快一点。
代码实现
#include<bits/stdc++.h>
#define LL long long
using namespace std;
const LL N=2e5+5;
struct node
{
LL l,r,id;
}b[N];
LL n,q,B,a[N],cnt[N],ans[N],sum;
bool cmp(node x,node y)
{
if(x.l/B==y.l/B)
{
if((x.l/B)&1)return x.r<y.r;
return x.r>y.r;
}
return x.l/B<y.l/B;
}
void del(LL x)
{
cnt[a[x]]--;
sum-=cnt[a[x]]*(cnt[a[x]]-1)/2;
}
void ins(LL x)
{
sum+=cnt[a[x]]*(cnt[a[x]]-1)/2;
cnt[a[x]]++;
}
int main()
{
scanf("%lld%lld",&n,&q);
B=sqrt(n);
for(int i=1;i<=n;i++)
{
scanf("%lld",&a[i]);
}
for(int i=1;i<=q;i++)
{
scanf("%lld%lld",&b[i].l,&b[i].r);
b[i].id=i;
}
sort(b+1,b+q+1,cmp);
LL l=1,r=0,sl,sr;
for(int i=1;i<=q;i++)
{
sl=b[i].l,sr=b[i].r;
while(l<sl)del(l++);
while(sl<l)ins(--l);
while(r<sr)ins(++r);
while(sr<r)del(r--);
ans[b[i].id]=sum;
}
for(int i=1;i<=q;i++)
{
printf("%lld\n",ans[i]);
}
}
如果觉得不错的话,就给一个赞吧!
作者是 DengDuck ,转载请注明出处
文章链接: https://www.cnblogs.com/dengduck/p/17519968.html
感谢您阅读!