「知识学习&日常训练」莫队算法(一)(Codeforce Round #340 Div.2 E)
题意 (CodeForces 617E)
已知一个长度为\(n\)的整数数列\(a[1],a[2],…,a[n]\),给定查询参数\(l,r\),问\([l,r]\)内,有多少连续子段满足异或和等于\(k\)。
也就是说,对于所有的\(x,y (l\le x\le y\le r)\),能够满足\(a[x]\oplus a[x+1]\oplus ...\oplus a[y]=k\)的\((x,y)\)有多少组。
分析
对于这种离线区间的查询问题(不涉及对区间的更改),我们可以使用莫队算法解决。这类问题是什么类型?对于序列上的区间询问问题,如果从\([l, r]\)的答案能够\(O(1)\)扩展到\([l+1,r],[l,r−1],[l - 1, r],[l, r + 1]\)的答案,那么可以在\(O(n\sqrt n)\)的复杂度内求出所有询问的答案。
这题为什么可以?因为对于\(x\)至\(y\)的区间异或和,我们可以用前缀异或和的\(x-1\)与\(y\)相异或来解决。
接下来讲讲具体的实现:
(参考:https://blog.sengxian.com/algorithms/mo-s-algorithm)
实现:离线后排序,顺序处理每个询问,暴力从上一个区间的答案转移到下一个区间答案。
排序方法:设定块的长度为\(S\),按照\((\lfloor\frac l S\rfloor, r)\)二元组从小到大排序。
复杂度分析:设序列长度为\(n\),询问个数为\(m\),块内元素共有\(k\)个。注意到上面拓展左右区间的\(O(1)\)复杂度,我们分左端点右端点讨论:
a) 右端点:由于在块内递增,所以在整个块内迁移的复杂度下界是\(O(n)\);对于跨块而言最多从\(n\)迁回\(1\),所以也是\(O(n)\)(注意一下,这里讨论的维度是块,因为我们保证了块内的递增,所以一个块内的元素最多迁移\(n\)次);一共有\(\frac{n}{k}\)块,故总复杂度是\(O(\frac{n^2}{k})\);
b) 左端点:注意到我们分块时不维护递增,所以每次询问最多迁移\(k\)次;共有\(m\)次询问,故总复杂度是\(O(km)\)。
综上,莫队算法的总复杂度是\(O(\frac{n^2}{k}+km)\)。当这里\(k\)是变量,当\(m,n\)同一数量级时,取\(k = \sqrt n\)有最优复杂度\(O(n\sqrt n)\)。
这题的具体实现:我们记\(mp[x]\)为异或和为x的个数。转移区间的时候(不失一般性,考虑区间纯右移),每增加一个点\(r\),这个点对于答案的贡献是\(mp[x\oplus a[r]]\)(异或的性质),同时,它增加了\(mp[a[r]]\)的个数。每减少一个点同理。
代码
参考:https://blog.csdn.net/swust_lian/article/details/50615109
/*
* Filename: cfr340d2e.cpp
* Date: 2018-11-09
*/
#include <bits/stdc++.h>
#define INF 0x3f3f3f3f
#define PB emplace_back
#define MP make_pair
#define fi first
#define se second
#define rep(i,a,b) for(repType i=(a); i<=(b); ++i)
#define per(i,a,b) for(repType i=(a); i>=(b); --i)
#define ZERO(x) memset(x, 0, sizeof(x))
#define MS(x,y) memset(x, y, sizeof(x))
#define ALL(x) (x).begin(), (x).end()
#define QUICKIO \
ios::sync_with_stdio(false); \
cin.tie(0); \
cout.tie(0);
#define DEBUG(...) fprintf(stderr, __VA_ARGS__), fflush(stderr)
using namespace std;
using pi=pair<int,int>;
using repType=int;
using ll=long long;
using ld=long double;
using ull=unsigned long long;
const int MAXN=100005;
const int BLOCK=400;
struct Node
{
ll l,r,id;
Node(ll _l=0, ll _r=0, ll _id=0):
l(_l), r(_r), id(_id) {}
bool operator < (const Node& rhs) const
{
if(l/BLOCK!=rhs.l/BLOCK) return l/BLOCK<rhs.l/BLOCK;
else return r<rhs.r;
}
};
vector<Node> vec;
ll s[MAXN];
ll ans[MAXN], mp[MAXN*200];
ll n,m,k;
int
main()
{
scanf("%lld%lld%lld", &n, &m, &k);
s[0]=0;
rep(i,1,n)
{
ll x; scanf("%lld", &x);
s[i]=s[i-1]^x;
}
rep(i,1,m)
{
ll l,r;
scanf("%lld%lld", &l, &r);
vec.PB(l-1,r,i); // why l-1: xor(a[x]~a[y])=k <-> s[x-1]^s[y]=k
}
sort(ALL(vec));
ZERO(mp);
ZERO(ans);
ll tmp=0;
int l=vec[0].l, r=vec[0].r;
rep(i,l,r)
{
tmp+=mp[s[i]^k];
mp[s[i]]++;
}
ans[vec[0].id]=tmp;
rep(i,1,m-1)
{
int L=vec[i].l,
R=vec[i].r;
while(l>L)
{
l--;
tmp+=mp[s[l]^k];
mp[s[l]]++; // mp: cnt of xor_sum = s[l]
}
while(l<L)
{
mp[s[l]]--;
tmp-=mp[s[l]^k];
l++;
}
while(r<R)
{
r++;
tmp+=mp[s[r]^k];
mp[s[r]]++;
}
while(r>R)
{
mp[s[r]]--;
tmp-=mp[s[r]^k];
r--;
}
ans[vec[i].id]=tmp;
}
rep(i,1,m) printf("%lld\n", ans[i]);
return 0;
}