BZOJ #4358 Solution / 类继承 + 回滚莫队
BZOJ #4358 permu Solution / 类继承 + 回滚莫队
本题解纯属自己闲得 egg 疼写出来的神必玩意,如果你是正常的 OIer,你没有必要学习这个东西。
0. 省流
当时我在校队分块题单里面看到了这道题。一眼回滚莫队。
由于每次写莫队我都要写四个函数,我这个懒人十分难受。
于是我突发奇想,我能不能偷个懒,省去回滚莫队的块内还原步骤,改为 \(O(1)\) 的整体还原?
还真可以。
于是就有了这篇题解。
1. 虚函数
虚函数是 C++ 多态的重要特性。
通过虚函数,我们可以实现通过基类访问派生类对应函数。
不过对于这道题来说,即使使用了虚函数,也与派生类同名覆盖没什么区别。
其声明形式如下:
virtual int fstv(int x); // 虚函数
virtual int fstv(int x) = 0; // 纯虚函数
virtual int fstv(int x)
{
// foobar
} // 虚函数定义
本题代码中使用纯虚函数,即上述第二种定义方式,因为基类并不会被直接用到。
2. \(O(1)\) 还原的数组
实际上,这种神奇的数组是不存在的。因为,所谓的 \(O(1)\) 还原是通过还原懒标记实现的,并且具有极大的常数。
其定义如下。
class relarray_base
{
protected:
int vis[N], clc = 1, arr[N];
public:
int &operator[](int x)
{
if (vis[x] ^ clc)
vis[x] = clc, arr[x] = fstv(x);
return arr[x];
}
void clear()
{
clc++;
}
virtual int fstv(int x) = 0;
};
其中,vis
数组的意义是对应的 arr
数组元素被访问时最大的 clc
值。如果碰到了新的还原点,那么按照 fstv(x)
的返回值重新赋值。
在 clear
函数中,有且仅有一个操作,即 clc++
,创建新还原点。
对于本题,我创建了三个以 relarray_base
为基的派生类。
1. relarray_f
/ 并查集
class relarray_f : public relarray_base
{
int fstv(int x)
{
return x;
}
};
该派生类定义了 fstv
函数,使并查集内的 \(f\) 数组回到初始时的不连通状态。
2. relarray_s
/ 连通块权值
class relarray_s : public relarray_base
{
int fstv(int x)
{
return 0;
}
};
该派生类定义了 fstv
函数,使并查集内每个连通块的权值初始化为 \(0\)。
3. relarray_cp
/ 复制数组 + 省略回滚
template <typename _Tp>
class relarray_cp : public relarray_base
{
protected:
_Tp &ptr;
public:
relarray_cp(_Tp &x) : ptr(x) {}
int fstv(int x)
{
return ptr[x];
}
};
该数组的构造函数接受一个额外参数 \(x\) ,即复制源。
同样定义了 fstv
函数,用以在执行 clear
后将原数组的值赋值为 ptr
数组中的对应值。
有了上述三个派生类后,解决本问题的码量将会大幅减少。
3. 题解
显然,在一般的回滚莫队中,在区间左边界所在块切换时,我们需要 \(O(n)\) 的计算量恢复额外数组;在一次询问中,我们需要两次 \(O(\sqrt n)\) 计算,分别用来临时扩展左区间以及还原至块右端点。
但是,在上面三个派生类的情况下,恢复数组只需要 \(O(1)\) 的计算量。
对于临时扩展左区间以及还原至块右端点的操作,我们可以改为 \(O(1)\) 的复制不带左边界块信息 + \(O(\sqrt n)\) 的扩展左区间。
处理值域连续块只需要在并查集上查询连通块权值最大值即可。
时间复杂度 \(O(n\sqrt n)\) 。
4. 代码
#include <iostream>
#include <cmath>
#include <algorithm>
using namespace std;
const int N = 5e4 + 10;
int n, m, a[N], sq, blk[N], res[N], mx, tmx;
class relarray_base
{
protected:
int vis[N], clc = 1, arr[N];
public:
int &operator[](int x)
{
if (vis[x] ^ clc)
vis[x] = clc, arr[x] = fstv(x);
return arr[x];
}
void clear()
{
clc++;
}
virtual int fstv(int x) = 0;
};
class relarray_f : public relarray_base
{
int fstv(int x)
{
return x;
}
};
class relarray_s : public relarray_base
{
int fstv(int x)
{
return 0;
}
};
template <typename _Tp>
class relarray_cp : public relarray_base
{
protected:
_Tp &ptr;
public:
relarray_cp(_Tp &x) : ptr(x) {}
int fstv(int x)
{
return ptr[x];
}
};
using rlf = relarray_f;
using rls = relarray_s;
rlf f, tf; // single block
rls siz, tsiz; // single block
relarray_cp<rlf> cf(f); // left expand
relarray_cp<rls> csiz(siz); // left expand
template <typename _Tp>
inline int find(_Tp &f, int x)
{
return x == f[x] ? x : f[x] = find(f, f[x]);
}
template <typename _Tp, typename _Tp2>
inline void merge(_Tp &f, _Tp2 &siz, int x, int y)
{
x = find(f, x), y = find(f, y);
if (x == y)
return;
f[y] = x;
siz[x] += siz[y];
siz[y] = 0;
}
struct qry
{
int l, r, id;
bool operator<(const qry &x) const
{
if (blk[l] ^ blk[x.l])
return blk[l] < blk[x.l];
if (r ^ x.r)
return r < x.r;
return id < x.id;
}
} qr[N];
inline int sg_blk(int l, int r)
{
tf.clear();
tsiz.clear();
int res = 0;
for (int i = l; i <= r; i++)
{
tsiz[find(tf, a[i])]++;
if (tsiz[find(tf, a[i] + 1)])
merge(tf, tsiz, a[i], a[i] + 1);
if (tsiz[find(tf, a[i] - 1)])
merge(tf, tsiz, a[i], a[i] - 1);
res = max(res, tsiz[find(tf, a[i])]);
}
return res;
}
inline void rexpand(int &x, int t)
{
while (x < t)
{
x++;
siz[find(f, a[x])]++;
if (siz[find(f, a[x] + 1)])
merge(f, siz, a[x], a[x] + 1);
if (siz[find(f, a[x] - 1)])
merge(f, siz, a[x], a[x] - 1);
mx = max(mx, siz[find(f, a[x])]);
}
}
inline void lexpand(int x, int t)
{
tmx = mx;
cf.clear();
csiz.clear();
while (x > t)
{
x--;
csiz[find(cf, a[x])]++;
if (csiz[find(cf, a[x] + 1)])
merge(cf, csiz, a[x], a[x] + 1);
if (csiz[find(cf, a[x] - 1)])
merge(cf, csiz, a[x], a[x] - 1);
tmx = max(tmx, csiz[find(cf, a[x])]);
}
}
int main()
{
scanf("%d%d", &n, &m);
sq = sqrt(n);
for (int i = 1; i <= n; i++)
{
scanf("%d", a + i);
blk[i] = (i - 1) / sq;
}
for (int i = 1; i <= m; i++)
{
scanf("%d%d", &qr[i].l, &qr[i].r);
qr[i].id = i;
}
sort(qr + 1, qr + m + 1);
for (int i = 1, lb = sq + 1, rb = sq; i <= m; i++)
{
if (blk[qr[i].l] ^ blk[qr[i - 1].l])
{
f.clear();
siz.clear();
rb = blk[qr[i].l] * sq + sq;
lb = rb + 1;
mx = 0;
}
if (blk[qr[i].l] == blk[qr[i].r])
{
res[qr[i].id] = sg_blk(qr[i].l, qr[i].r);
continue;
}
rexpand(rb, qr[i].r);
lexpand(lb, qr[i].l);
res[qr[i].id] = tmx;
}
for (int i = 1; i <= m; i++)
{
printf("%d\n", res[i]);
}
}