莫队算法总结
莫队算法 总结
最近两天学习了一下莫队,感觉莫队算法还是挺好用的(现在看到离线询问就想莫队...
就稍微写一下总结吧,加深一下对算法的理解。
-
普通莫队
核心思想:莫队算法一般用来离线处理一系列无修改的区间询问问题,通过将所有的询问保存下来,并且将所有的询问区间进行适当地排序,从而达到降低时间复杂度的效果。
对于所有的询问区间\([l_i,r_i]\),如果暴力地进行区间端点移动,那么对于一次询问,区间端点可能移动\(n\)的长度。假设询问的规模与\(n\)同级,那么复杂度就为\(O(n^2)\)。
但其实,我们可以巧妙地安排区间顺序以降低时间复杂度。
莫队算法的思想如下:
将区间分为\(\sqrt{n}\)块,每块的长度也为\(\sqrt{n}\),之后对所有的询问区间排序,如果区间左端点在同一块内,则按右端点排序;否则则按左端点所在块进行排序。
就这样排序过后,暴力计算就行了,可以证明,时间复杂度为\(O(n^{\frac{3}{2}})\)。
下面给出简单的证明:
假设区间左端点在同一块内,那么一次询问左端点最多移动\(\sqrt{n}\),由于右端点是单增的,则右端点移动总的复杂度为\(O(n)\),此时端点移动的总复杂度为\(O(n^{\frac{3}{2}})\)。(注意这里是均摊意义上的复杂度)
如果区间左端点不在同一块,也就是左端点跨块移动,因为一共有\(\sqrt{n}\)块,每次右端点的移动最多\(O(n)\),此时总的时间复杂度也为\(O(n^{\frac{3}{2}})\)。
所以经过分块过后,时间复杂度可以降为\(O(n^\frac{3}{2})\)。
可以先通过几道例题感受一下:
洛谷P1494 小Z的袜子
设\(cnt[i]\)为第\(i\)种颜色的袜子的个数,当前区间为\([l,r]\),那么容易知道所求的答案为\(\frac{\sum_{i=1}^{k}{C_{cnt[i]}^2}}{C_{r-l+1}^{2}}\)。
因为分母是与区间长度有关,我们只用考虑区间端点变化时,分子的变化情况就行了。
先单独把分子拿出来:\(\sum_{i=1}^{k}{C_{cnt[i]}^2}\),当区间范围增加一时,会存在一个\(t\),有\(cnt[t]+1\),在这个求和式中,其余项不会改变,那么我们就只用看这一项对答案的影响。
影响即为:\(C_{cnt[t]+1}^2-C_{cnt[t]}^2\),那么在进行答案更新时算一下这个式子就好了。
对于区间范围减小的情况分析也同理。
代码如下:
Code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 50005;
int n, m, block;
int a[N];
struct query{
int l, r, id ;
}Q[N];
struct Ans{
ll p, q;
}answer[N];
bool cmp(query x, query y) {
if((x.l - 1) / block + 1 == (y.l - 1) / block + 1) return x.r < y.r;
return (x.l - 1) / block + 1 < (y.l - 1) / block + 1 ;
}
ll gcd(ll A, ll B) {
return B == 0 ? A : gcd(B, A % B) ;
}
ll ans ;
ll cnt[N] ;
void update(int pos, int sign) {
ans -= cnt[a[pos]] * cnt[a[pos]] ;
cnt[a[pos]] += sign ;
ans += cnt[a[pos]] * cnt[a[pos]] ;
}
int main() {
scanf("%d%d",&n, &m) ;
block = (int)sqrt(n) ;
for(int i = 1; i <= n; i++) scanf("%d", &a[i]) ;
for(int i = 1; i <= m; i++) {
scanf("%d%d",&Q[i].l, &Q[i].r) ;
Q[i].id = i ;
}
sort(Q + 1, Q + m + 1, cmp) ;
int l = 1, r = 0;
for(int i = 1; i <= m; i++) {
for(; r < Q[i].r; r++) update(r + 1, 1) ;
for(; r > Q[i].r; r--) update(r, -1) ;
for(; l < Q[i].l; l++) update(l, -1) ;
for(; l > Q[i].l; l--) update(l - 1, 1) ;
answer[Q[i].id].p = ans - Q[i].r + Q[i].l - 1;
answer[Q[i].id].q = 1ll * (Q[i].r - Q[i].l + 1) * (Q[i].r - Q[i].l) ;
if(Q[i].l == Q[i].r) answer[Q[i].id].p = 0, answer[Q[i].id].q = 1;
ll g = gcd(answer[Q[i].id].p, answer[Q[i].id].q) ;
answer[Q[i].id].p /= g; answer[Q[i].id].q /= g;
}
for(int i = 1; i <= m; i++) printf("%lld/%lld\n",answer[i].p, answer[i].q) ;
return 0;
}
给出的数字串挺长的,但是质数\(p\)不是很大。
我们知道,如果一个数字\(t\)为\(p\)的倍数,那么就有\(t\mod p=0\)。但是区间中的子串很多,我们直接时间复杂度等同于暴力。所以我们可以考虑将问题转化一下。
设串\(s\)所在区间为\([l,r]\),串的长度为\(n\),那么我们知道\(s*10^{r-l+1}=t[l,l+1,\cdots,n]-t[r+1,r+2,\cdots,n]*10^{r-l+1}\)。
所以当质数\(p\)不为2和5时,\(s\mod p=0 => (t[l,l+1,\cdots,n]-t[r+1,r+2,\cdots,n]*10^{r-l+1}) \mod p=0 => t[l,l+1,\cdots,n]\mod p=t[r+1,r+2,\cdots,n]\mod p\)。
所以我们就可以维护一个数组\(f[i]\),表示后缀\(i\)对\(p\)取余的值为多少,那么我们就可以将一个区间为\([l,r]\)的询问转化为\([l,r+1]\)中有多少对\(f\)相等了。
之后就用莫队来搞,计算区间范围增加或者减小对答案的影响就好了。思路同上一题类似。
对于\(p\)为2或者5的情况,特判一波,维护前缀个数就好了。
代码如下:
Code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 2e5 + 5;
ll p, cnt;
int block, n;
char s[N] ;
ll f[N], d[N];
ll num[N] ;
ll sum[N][3] ;
struct Query{
int l, r, id;
ll ans ;
}q[N];
int Q;
bool cmp(Query x, Query y) {
if((x.l - 1) / block + 1 == (y.l - 1) / block + 1) return x.r < y.r;
return (x.l - 1) / block + 1 < (y.l - 1) / block + 1 ;
}
bool cmp2(Query x, Query y) {
return x.id < y.id ;
}
void sol1() {
for(int i = 1; i <= n; i++) {
sum[i][0] = sum[i - 1][0] ;
sum[i][1] = sum[i - 1][1] ;
sum[i][2] = sum[i - 1][2] ;
if(p == 2 && (s[i] - '0') % p == 0) sum[i][0] += i, sum[i][2]++;
if(p == 5 && (s[i] - '0') % p == 0) sum[i][1] += i, sum[i][2]++;
}
for(int i = 1; i <= Q; i++) {
int l = q[i].l, r = q[i].r;
ll k;
if(p == 2) k = sum[r][0] - sum[l - 1][0] ;
else k = sum[r][1] - sum[l - 1][1] ;
q[i].ans = k - (sum[r][2] - sum[l - 1][2]) * (l - 1) ;
}
}
void update2(int pos, int sign) {
cnt -= (num[f[pos]] - 1) * num[f[pos]] / 2;
num[f[pos]] += sign;
cnt += (num[f[pos]] - 1) * num[f[pos]] / 2;
}
void sol2() {
int l = 1, r = 0 ;
for(int i = 1; i <= Q; i++) {
q[i].r += 1;
for(; r < q[i].r; r++) update2(r + 1, 1) ;
for(; r > q[i].r; r--) update2(r, -1) ;
for(; l > q[i].l; l--) update2(l - 1, 1) ;
for(; l < q[i].l; l++) update2(l, -1) ;
q[i].ans = cnt ;
}
}
int main() {
scanf("%lld%s%d", &p, s + 1, &Q) ;
n = strlen(s + 1) ;
block = (int)sqrt(n) ;
for(int i = 1; i <= Q; i++) {
scanf("%d%d",&q[i].l, &q[i].r) ;
q[i].id = i;
}
sort(q + 1, q + Q + 1, cmp) ;
ll x = 0, qp = 1;
int flag = -1;
for(int i = n; i >= 1; i--) {
x = (x + (s[i] - '0') * qp % p) % p;
d[i] = f[i] = x;
if(f[i] == 0) flag = i;
qp = qp * 10 % p;
}
sort(d + 1, d + n + 1) ;
int D = unique(d + 1, d + n + 1) - d - 1;
for(int i = 1; i <= n; i++) f[i] = lower_bound(d + 1, d + n + 1, f[i]) - d;
if(flag > 0) f[n + 1] = f[flag] ;
if(p == 2 || p == 5) sol1() ;
else sol2() ;
sort(q + 1, q + Q + 1, cmp2) ;
for(int i = 1; i <= Q; i++) printf("%lld\n", q[i].ans) ;
return 0;
}
感觉这个题挺好的。没想到还可以用莫队来搞。
对于区间\([l,r]\),假设我们要将\(r\)增加1,那么就会多出\(r-l+2\)个序列,我们就分析他们对答案的影响。
假设区间\([l,r]\)中最小值所在位置为\(p\),那么很显然,左端点在\([l,l+1,\cdots,p]\)时,区间最小值就为\(a[p]\)。
对于\(r+1\)而言,如果我们找到左边第一个比他小的位置为\(k\),那么此时对答案的贡献就为\((r-k+2)*a[k]\);同理对\(k\)也可以执行同样的操作。最后必然会存在一个位置\(q\),其左边第一个比他小的位置为\(q\),那么操作在这里就终止了。
每次这么操作时间复杂度过高,发现可以维护一个类似于前缀和一样的东西,递推地来维护就行了。设该前缀和函数为\(f\),那么区间右端点增加一位对答案的贡献为:\(a[p]*(p-l+1)+f[r+1]-f[p]\)。
这样就可以O(1)算出对答案的影响了。
左端点的情况也类似考虑。
代码如下:
Code
#include <bits/stdc++.h>
#define INF 0x3f3f3f3f
using namespace std;
typedef long long ll;
const int N = 2e5 + 5;
int n, m, block;
int a[N];
struct Query{
int l, r, id;
ll ans;
}q[N];
bool cmp(Query A, Query B) {
if((A.l - 1) / block + 1 == (B.l - 1) / block + 1) return A.r < B.r;
return (A.l - 1) / block + 1< (B.l - 1) / block + 1;
}
bool cmp_id(Query A, Query B) {
return A.id < B.id ;
}
int l[N], r[N] ;
int sta[N], top;
ll f12[N], f21[N];
int f[N][22], pos[N][22], Log2[N];
ll ans ;
int Get_min(int L, int R) {
ll k = Log2[R - L + 1];
if(f[L][k] > f[R - (1LL << k) + 1][k]) return pos[R - (1LL << k) + 1][k] ;
return pos[L][k] ;
}
void update1(int pos, int L, int R, int sign) {
int p = Get_min(L, R) ;
ll sum = f12[R] - f12[p] + 1ll * (p - L + 1) * a[p];
ans += 1ll * sign * sum;
}
void update2(int pos, int L, int R, int sign) {
int p = Get_min(L, R) ;
ll sum = f21[L] - f21[p] + 1ll * (R - p + 1) * a[p] ;
ans += 1ll * sign * sum ;
}
int main() {
ios::sync_with_stdio(false); cin.tie(0);
cin >> n >> m ;
Log2[1] = 0;
for(int i = 2; i <= n; i++) Log2[i] = Log2[i >> 1] + 1;
block = sqrt(n) ;
memset(f, INF, sizeof(f)) ;
for(int i = 1; i <= n; i++) {
cin >> a[i] ;
f[i][0] = a[i] ;
pos[i][0] = i ;
}
for(int j = 1; j <= 17; j++) {
for(int i = 1; i + (1 << (j - 1)) <= n; i++) {
if(f[i][j - 1] > f[i + (1 << (j - 1))][j - 1]) {
f[i][j] = f[i + (1 << (j - 1))][j - 1] ;
pos[i][j] = pos[i + (1 << (j - 1))][j - 1] ;
} else {
f[i][j] = f[i][j - 1];
pos[i][j] = pos[i][j - 1] ;
}
}
}
for(int i = 1; i <= n + 1; i++) {
while(top > 0 && a[sta[top]] >= a[i]) r[sta[top--]] = i ;
sta[++top] = i;
}
top = 0;
for(int i = n; i >= 0; i--) {
while(top > 0 && a[sta[top]] >= a[i]) l[sta[top--]] = i;
sta[++top] = i;
}
for(int i = 1; i <= n; i++)
f12[i] = f12[l[i]] + 1ll * (i - l[i]) * a[i] ;
for(int i = n; i >= 1; i--)
f21[i] = f21[r[i]] + 1ll * (r[i] - i) * a[i] ;
for(int i = 1; i <= m; i++) {
int L, R;
cin >> L >> R;
q[i].l = L; q[i].r = R;
q[i].id = i;
}
sort(q + 1, q + m + 1, cmp) ;
int L = 1, R = 0;
for(int i = 1; i <= m; i++) {
for(; R < q[i].r; R++) update1(R + 1, L, R + 1, 1) ;
for(; R > q[i].r; R--) update1(R, L, R, -1) ;
for(; L < q[i].l; L++) update2(L, L, R, -1) ;
for(; L > q[i].l; L--) update2(L - 1, L - 1, R, 1) ;
q[i].ans = ans ;
}
sort(q + 1, q + m + 1, cmp_id) ;
for(int i = 1; i <= m; i++)
cout << q[i].ans << '\n' ;
return 0;
}
-
带修改莫队
之前说的莫队是不支持修改的,但其实也可以支持修改,只需要再加一维“时间状态”就行了,对于每个询问,新增一维,变为\([l,r,k]\),表示当前区间为\([l,r]\),之前经过\(k\)次修改操作的询问。
为什么这样是正确的呢?
因为我们如果知道了\([l,r,k]\)的答案,那么就很容易知道\([l+1,r,k],[l-1,r,k],[l,r-1,k],[l,r+1,k],[l,r,k-1],[l,r,k+1]\)对答案的影响。
具体来说,修改时间维度时,看看修改的位置是否在\([l,r]\)中,如果在则会对答案产生影响,否则直接修改就是了。之后区间端点左右移动时,遇到的位置也一定是完成\(k\)次修改过后的值了。
此时我们还是将区间进行分块,但现在要分为\(n^\frac{2}{3}\)块,每块长度为\(n^\frac{1}{3}\)。然后以左端点所在的块为第一关键字,右端点所在的块为第二关键字,修改次数为第三关键字进行排序。
可以证明这样的时间复杂度是\(O(n^\frac{5}{3})\)的。
证明方法就类似于上面的分析。
来看一道例题:
这就是个待修改莫队的模板题,多了一维对时间的修改,详细见代码吧:
Code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 50005, MAX = 1e6 + 5;
int n, m, block, num, M, l, r, t, ans;
char ss[5];
int c[N], cnt[MAX], last[N];
struct Upd{
int pos, v;
}upd[N];
struct query{
int l, r, ans, id, k;
}q[N];
bool cmp(query a, query b) {
if((a.l - 1) / block == (b.l - 1) / block && (a.r - 1) / block == (b.r - 1) / block) return a.k < b.k;
else if((a.l - 1) / block == (b.l - 1) / block) return (a.r - 1) / block < (b.r - 1) / block ;
return (a.l - 1) / block < (b.l - 1) / block;
}
bool cmp_id(query a, query b) {
return a.id < b.id;
}
void update_add(int T) {
int pos = upd[T].pos, v = upd[T].v;
last[T] = c[pos] ;
if(l <= pos && pos <= r) {
cnt[c[pos]]--;
if(cnt[c[pos]] == 0) ans--;
cnt[v]++;
if(cnt[v] == 1) ans++;
}
c[pos] = v;
}
void update_del(int T) {
int pos = upd[T].pos, v = upd[T].v;
if(l <= pos && pos <= r) {
cnt[v]--;
if(cnt[v] == 0) ans--;
c[pos] = last[T] ;
cnt[c[pos]]++;
if(cnt[c[pos]] == 1) ans++;
} else c[pos] = last[T] ;
}
void update(int pos, int val) {
cnt[c[pos]] += val;
if(val == 1) {
if(cnt[c[pos]] == 1) ans++;
} else if(val == -1)
if(cnt[c[pos]] == 0) ans--;
}
int main() {
scanf("%d%d",&n, &m) ;
block = pow(n, 0.666666) ;
for(int i = 1; i <= n; i++) scanf("%d", &c[i]) ;
for(int i = 1; i <= m; i++) {
scanf("%s",ss) ;
if(ss[0] == 'R') {
int pos, v;
scanf("%d%d",&pos, &v) ;
upd[++num].pos = pos; upd[num].v = v ;
} else {
int l, r;
scanf("%d%d",&l, &r) ;
q[++M].l = l; q[M].r = r;
q[M].id = M; q[M].k = num;
}
}
sort(q + 1, q + M + 1, cmp) ;
l = 1, r = 0, t = 0;
for(int i = 1; i <= M; i++) {
for(; t < q[i].k; t++) update_add(t + 1) ;
for(; t > q[i].k; t--) update_del(t) ;
for(; r < q[i].r; r++) update(r + 1, 1) ;
for(; r > q[i].r; r--) update(r, -1) ;
for(; l < q[i].l; l++) update(l, -1) ;
for(; l > q[i].l; l--) update(l - 1, 1) ;
q[i].ans = ans ;
}
sort(q + 1, q + M + 1, cmp_id) ;
for(int i = 1; i <= M; i++) printf("%d\n", q[i].ans) ;
return 0 ;
}
-
树上莫队
如果可以对树进行分块的话,那么也可以对树上的询问用莫队来搞。刚好有一道树上分块的模板题。
那么树上莫队的具体做法就为,首先将树进行分块,然后对所有的询问\([x,y]\),首先让\(x\)的时间戳小于\(y\)的时间戳,然后就按照\(x\)所在的块为第一关键字,以y的时间戳为第二关键字进行排序就好了。
之后考虑询问间的转移,方法为直接将\(x_i->x_{i+1}\)路径上面的所有点除开它们lca的状态取反,同理也将\(y_i->y_{i+1}\)路径上面的所有点除开它们lca的状态取反,计算答案就是了。
具体证明直接引用vfk的博客:
用S(v, u)代表 v到u的路径上的结点的集合。
用root来代表根结点,用lca(v, u)来代表v、u的最近公共祖先。
那么
S(v, u) = S(root, v) xor S(root, u) xor lca(v, u)
其中xor是集合的对称差。
简单来说就是节点出现两次消掉。
lca很讨厌,于是再定义
T(v, u) = S(root, v) xor S(root, u)
观察将curV移动到targetV前后T(curV, curU)变化:
T(curV, curU) = S(root, curV) xor S(root, curU)
T(targetV, curU) = S(root, targetV) xor S(root, curU)
取对称差:
T(curV, curU) xor T(targetV, curU)= (S(root, curV) xor S(root, curU)) xor (S(root, targetV) xor S(root, curU))
由于对称差的交换律、结合律:
T(curV, curU) xor T(targetV, curU)= S(root, curV) xorS(root, targetV)
两边同时xor T(curV, curU):
T(targetV, curU)= T(curV, curU) xor S(root, curV) xor S(root, targetV)
发现最后两项很爽……哇哈哈
T(targetV, curU)= T(curV, curU) xor T(curV, targetV)
(有公式恐惧症的不要走啊 T_T)
也就是说,更新的时候,xor T(curV, targetV)就行了。
即,对curV到targetV路径(除开lca(curV, targetV))上的结点,将它们的存在性取反即可。
因为lca我们不会算,所以最后单独考虑一下lca就行了。
这是树上莫队的第一种解法,另外还有一种就是直接将树转化为dfs序,压缩成线性的,同时每个结点维护两个时间戳,一个是进去的时间戳,一个是出来的时间戳。
那么对于树上的路径比如从\(x\)到\(y\),若\(LCA(x,y)\)为其中之一,那么两个的路径在dfs序中的体现就为\(in[x]->in[y]\);否则就为\(out[x]->in[y]\)。
这样写的话也需要一个数组来记录当前结点是否被算入答案中,每到一个位置也要将相应的状态取反。这里注意第二种情况lca也不会算上,所以也要单独考虑一下lca。
既然有了树上莫队,也有树上带修改莫队,好吧,其实原理都是差不多的。
看一个例题:
这基本上就是莫队算法的集大成者了。对答案的影响很好计算,维护一种颜色出现的次数就行了。
主要就是代码,我写了两种,一种是dfs序的,一种是树上分块的。
dfs序
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e5 + 5;
int n, m, qq, block;
ll w[N], c[N], in[N], out[N], v[N];
vector <int> g[N] ;
struct Query{
int l, r, id, k;
ll ans ;
}q[N];
struct Upd{
int x, y, last;
}upd[N];
bool cmp_id(Query A, Query B) {
return A.id < B.id ;
}
bool cmp(Query A, Query B) {
if((A.l - 1) / block == (B.l - 1) / block && (A.r - 1) / block == (B.r - 1) / block) return A.k < B.k;
if((A.l - 1) / block == (B.l - 1) / block) return (A.r - 1) / block < (B.r - 1) / block;
return (A.l - 1) / block < (B.l - 1) / block ;
}
int dfn;
ll a[2 * N], f[N][22], deep[N], pre[N];
void dfs(int u, int fa) {
in[u] = ++dfn;
a[dfn] = u ;
deep[u] = deep[fa] + 1;
for(auto v : g[u]) {
if(v == fa) continue ;
f[v][0] = u;
for(int i = 1; i <= 17; i++) f[v][i] = f[f[v][i - 1]][i - 1] ;
dfs(v, u);
}
out[u] = ++dfn;
a[dfn] = u;
}
int LCA(int x, int y) {
if(deep[x] < deep[y]) swap(x, y) ;
for(int i = 17; i >= 0; i--)
if(deep[f[x][i]] >= deep[y]) x = f[x][i] ;
if(x == y) return x;
for(int i = 17; i >= 0; i--)
if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i] ;
return f[x][0] ;
}
ll ans ;
int l, r, t, qnum, num;
bool vis[2 * N];
ll cnt[N] ;
void update(int u) {
int col = c[u] ;
if(vis[u]) ans -= 1ll * w[cnt[col]--] * v[col] ;
else ans += 1ll * w[++cnt[col]] * v[col] ;
vis[u] ^= 1;
}
void update_t(int T, int sign) {
int u = upd[T].x, col = upd[T].y;
if(sign == -1) col = upd[T].last;
if(vis[u]) {
update(u);
c[u] = col;
update(u);
} else c[u] = col;
}
int main() {
ios::sync_with_stdio(false); cin.tie(0);
cin >> n >> m >> qq;
block = pow(n, 0.666666) ;
for(int i = 1; i <= m; i++) cin >> v[i] ;
for(int i = 1; i <= n; i++) cin >> w[i] ;
for(int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
g[u].push_back(v) ;
g[v].push_back(u) ;
}
dfs(1, 0);
for(int i = 1; i <= n; i++) cin >> c[i], pre[i] = c[i];
for(int i = 1; i <= qq; i++) {
int op, x, y;
cin >> op >> x >> y;
if(op == 1) {
if(in[x] > in[y]) swap(x, y) ;
int lca = LCA(x, y) ;
q[++num].r = in[y];
q[num].k = qnum;
q[num].id = num;
if(lca == x) q[num].l = in[x] ;
else q[num].l = out[x];
} else {
upd[++qnum].x = x;
upd[qnum].y = y;
//pre[qnum] = (qnum == 1 ? c[x] : upd[qnum - 1].y) ;
upd[qnum].last = pre[x];
pre[x] = y;
}
}
sort(q + 1, q + num + 1, cmp) ;
l = 1, r = 0, t = 0;
for(int i = 1; i <= num; i++) {
for(; t < q[i].k; t++) update_t(t + 1, 1) ;
for(; t > q[i].k; t--) update_t(t, -1) ;
for(; r < q[i].r; r++) update(a[r + 1]) ;
for(; r > q[i].r; r--) update(a[r]) ;
for(; l < q[i].l; l++) update(a[l]) ;
for(; l > q[i].l; l--) update(a[l - 1]) ;
int lca = LCA(a[l], a[r]) ;
if(lca != a[l] && lca != a[r]) {
update(lca) ;
q[i].ans = ans ;
update(lca) ;
} else q[i].ans = ans ;
}
sort(q + 1, q + num + 1, cmp_id) ;
for(int i = 1; i <= num; i++)
cout << q[i].ans << '\n' ;
return 0;
}
树上分块
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e5 + 5;
int n, m, qq, block;
int w[N], c[N], in[N], v[N];
int dfn;
int f[N][22], deep[N], pre[N];
int sta[N], bel[N];
int top, tot;
vector <int> g[N] ;
struct Query{
int l, r, id, k;
ll ans ;
}q[N];
struct Upd{
int x, y, last;
}upd[N];
bool cmp_id(Query A, Query B) {
return A.id < B.id ;
}
bool cmp(Query A, Query B) {
if(bel[A.l] == bel[B.l] && bel[A.r] == bel[B.r]) return A.k < B.k;
if(bel[A.l] == bel[B.l]) return bel[A.r] < bel[B.r] ;
return bel[A.l] < bel[B.l] ;
}
void dfs(int u, int fa) {
in[u] = ++dfn;
deep[u] = deep[fa] + 1;
int tmp = top ;
for(auto v : g[u]) {
if(v == fa) continue ;
f[v][0] = u;
for(int i = 1; i <= 16; i++) f[v][i] = f[f[v][i - 1]][i - 1] ;
dfs(v, u);
if(top - tmp >= block) {
tot++;
while(top > tmp) bel[sta[top--]] = tot;
}
}
sta[++top] = u ;
}
int LCA(int x, int y) {
if(deep[x] < deep[y]) swap(x, y) ;
for(int i = 16; i >= 0; i--)
if(deep[f[x][i]] >= deep[y]) x = f[x][i] ;
if(x == y) return x;
for(int i = 16; i >= 0; i--)
if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i] ;
return f[x][0] ;
}
ll ans ;
int l, r, t, qnum, num;
bool vis[N];
ll cnt[N] ;
void modify(int x) {
int col = c[x] ;
if(vis[x]) ans -= 1ll * w[cnt[col]--] * v[col] ;
else ans += 1ll * w[++cnt[col]] * v[col] ;
vis[x] ^= 1;
}
void update(int x, int y) {
while(x != y) {
if(deep[x] >= deep[y]) modify(x), x = f[x][0] ;
else modify(y), y = f[y][0] ;
}
}
void change(int x, int col) {
if(vis[x]) {
modify(x) ;
c[x] = col ;
modify(x) ;
} else c[x] = col ;
}
int main() {
ios::sync_with_stdio(false); cin.tie(0);
cin >> n >> m >> qq;
block = pow(n, 0.666666) ;
for(int i = 1; i <= m; i++) cin >> v[i] ;
for(int i = 1; i <= n; i++) cin >> w[i] ;
for(int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
g[u].push_back(v) ;
g[v].push_back(u) ;
}
dfs(1, 0);
for(int i = 1; i <= n; i++) cin >> c[i], pre[i] = c[i] ;
for(int i = 1; i <= qq; i++) {
int op, x, y;
cin >> op >> x >> y;
if(op == 1) {
if(in[x] > in[y]) swap(x, y) ;
q[++num].id = num;q[num].l = x;
q[num].r = y;q[num].k = qnum;
} else {
upd[++qnum].x = x;upd[qnum].y = y;
upd[qnum].last = pre[x] ;
pre[x] = upd[qnum].y ;
}
}
sort(q + 1, q + num + 1, cmp) ;
l = q[1].l, r = q[1].r, t = 0;
update(l, r);
for(int i = 1; i <= num; i++) {
for(;t < q[i].k; t++) change(upd[t + 1].x, upd[t + 1].y) ;
for(;t > q[i].k; t--) change(upd[t].x, upd[t].last) ;
update(l, q[i].l) ;
update(r, q[i].r) ;
int lca = LCA(q[i].l, q[i].r) ;
modify(lca) ;
q[q[i].id].ans = ans ;
modify(lca) ;
l = q[i].l, r = q[i].r ;
}
for(int i = 1; i <= num; i++) cout << q[i].ans << '\n' ;
return 0;
}
再看看这个题:
CF375D Tree and Queries
这里询问的是出现次数大于等于k的颜色有多少种,看似比较棘手。实际上我们维护一个数组\(sum[i]\),表示大于等于\(i\)的颜色有多少种就行了。这个稍微想想还是比较清楚的。
代码如下:
Code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 2e5 + 5;
int n, m, block;
int c[2 * N], a[2 * N], cnt[N];
int ans ;
vector <int> g[N];
struct Query{
int l, r, k, id, ans;
}q[N];
bool cmp(Query A, Query B) {
if((A.l - 1) / block == (B.l - 1) / block) return A.r < B.r;
return (A.l - 1) / block < (B.l - 1) / block ;
}
bool cmp_id(Query A, Query B) {
return A.id < B.id ;
}
int in[N], out[N] ;
int dfn, tot;
bool vis[2 * N], has[2 * N];
int sum[N] ;
void update(int pos, int val) {
int col = c[a[pos]] ;
if(val == 1) {
if(vis[a[pos]]) return ;
vis[a[pos]] = 1;
sum[++cnt[col]]++;
} else {
if(!vis[a[pos]]) return ;
vis[a[pos]] = 0;
sum[cnt[col]--]--;
}
}
void dfs(int u, int fa) {
in[u] = ++dfn;
a[dfn] = u ;
int t = dfn;
for(auto v : g[u]) {
if(v == fa) continue ;
dfs(v, u) ;
}
out[u] = ++dfn;
a[dfn] = u ;
}
int main() {
ios::sync_with_stdio(false); cin.tie(0);
cin >> n >> m;
block = sqrt(n) ;
for(int i = 1; i <= n; i++) cin >> c[i] ;
for(int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
dfs(1, 0);
for(int i = 1; i <= m; i++) {
int v;
cin >> v >> q[i].k;
q[i].id = i;
q[i].l = in[v] ; q[i].r = out[v] ;
}
sort(q + 1, q + m + 1, cmp);
int l = 1, r = 0;
for(int i = 1; i <= m; i++) {
int k = q[i].k ;
for(; r < q[i].r; r++) update(r + 1, 1) ;
for(; r > q[i].r; r--) update(r, -1) ;
for(; l < q[i].l; l++) update(l, -1) ;
for(; l > q[i].l; l--) update(l - 1, 1) ;
q[i].ans = sum[k] ;
}
sort(q + 1, q + m + 1, cmp_id) ;
for(int i = 1; i <= m; i++)
cout << q[i].ans << '\n' ;
return 0;
}
最后再来看一道例题:
BZOJ3289:Mato的文件管理
学过莫队之后是不是感觉很简单?
每次区间转移用树状数组维护信息即可。
代码如下:
Code
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <cmath>
using namespace std;
typedef long long ll;
const int N = 50005;
int c[N], a[N], b[N];
int l, r ;
int n, block;
struct Query{
int l, r, id ;
ll ans ;
}q[N];
bool cmp(Query A, Query B) {
if((A.l - 1) / block == (B.l - 1) / block) return A.r < B.r;
return (A.l - 1) / block < (B.l - 1) / block ;
}
int lowbit(int x) {
return x & (-x) ;
}
void add(int x, int val) {
for(int i = x; i < N; i += lowbit(i)) c[i] += val;
}
ll query(int x) {
ll ans = 0;
for(int i = x; i > 0; i -= lowbit(i)) ans += c[i];
return ans ;
}
ll ans ;
void update(int x, int v, int sign) {
if(sign == 1) {
if(v == 1) {
add(a[x], 1) ;
int sum = query(a[x]) ;
ans += r - l + 2 - sum ;
} else {
int sum = query(a[x]) ;
ans -= (r - l + 1 - sum) ;
add(a[x], -1) ;
}
} else {
if(v == 1) {
int sum = query(a[x] - 1) ;
ans += sum;
add(a[x], 1) ;
} else {
add(a[x], -1) ;
int sum = query(a[x] - 1) ;
ans -= sum ;
}
}
}
int main() {
ios::sync_with_stdio(false); cin.tie(0);
cin >> n;
block = sqrt(n) ;
for(int i = 1; i <= n; i++) cin >> a[i], b[i] = a[i];
sort(b + 1, b + n + 1);
int D = unique(b + 1, b + n + 1) - b - 1;
for(int i = 1; i <= n; i++) a[i] = lower_bound(b + 1, b + D + 1, a[i]) - b;
int Q;
cin >> Q;
for(int i = 1; i <= Q; i++) {
int l, r;
cin >> l >> r;
q[i].l = l; q[i].r = r; q[i].id = i;
}
sort(q + 1, q + Q + 1, cmp) ;
l = 1, r = 0;
for(int i = 1; i <= Q; i++) {
for(; r < q[i].r; r++) update(r + 1, 1, 1) ;
for(; r > q[i].r; r--) update(r, -1, 1) ;
for(; l < q[i].l; l++) update(l, -1, -1) ;
for(; l > q[i].l; l--) update(l - 1, 1, -1) ;
q[q[i].id].ans = ans ;
}
for(int i = 1; i <= Q; i++)
cout << q[i].ans << '\n' ;
return 0;
}
重要的是自信,一旦有了自信,人就会赢得一切。