bzoj4940 [Ynoi2016]这是我自己的发明 莫队+dfs序
题目传送门
https://lydsy.com/JudgeOnline/problem.php?id=4940
题解
对于换根操作,处理方法就很套路了。
首先假定以 \(1\) 为根做一遍 dfs。那么当根为 \(rt\) 时,对于一个点 \(x\),如果 \(rt\) 不在 \(x\) 以 \(1\) 为根时的子树中,那么 \(x\) 在以 \(rt\) 为根时的子树与以 \(1\) 为根时的子树相同。
如果 \(rt\) 在 \(x\) 以 \(1\) 为根时的子树中(且 \(rt \ne x\)),设 \(y\) 为以 \(1\) 为根时 \(x\) 的孩子中子树包含 \(rt\) 的那一个,那么 \(x\) 在以 \(rt\) 为根时的子树就是整棵树去掉 \(y\) 以 \(1\) 为根时的子树后剩下的部分。
如果 \(rt = x\),那么显然子树就是整棵树。
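由于以 \(1\) 为根时每棵子树在 dfs 序上都是一段连续区间,换根后 \(x\) 的子树要么是一个区间,要么是整棵树去掉一段区间后剩下的至多两个区间。因此一次询问可以拆成至多 \(2 \times 2 = 4\) 个如下的问题:
给定两个区间 \([l_1, r_1]\) 和 \([l_2, r_2]\),求满足 \(i \in [l_1, r_1]\)、\(j \in [l_2, r_2]\) 且 dfs 序为 \(i\) 与 \(j\) 的两个点点权相同的有序对 \((i, j)\) 的个数。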
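有了前面一道题 [Snoi2017]一个简单的询问 的经验,我们知道这个问题的做法是把一个有四个参数的询问差分成四个只有两个参数的前缀询问:记 \(g(x, y)\) 为满足 \(i \le x\)、\(j \le y\) 且点权相同的有序对 \((i, j)\) 的个数,则答案为 \(g(r_1, r_2) - g(l_1 - 1, r_2) - g(r_1, l_2 - 1) + g(l_1 - 1, l_2 - 1)\),所有的 \(g\) 都可以用莫队离线求出。
具体做法参见我对这道题的题解。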
#include<bits/stdc++.h>
// Adjacency-list iteration, used as `for fec(i, x, y) ...`:
// i walks the edge indices out of x (head[x], g[i].ne, ... until 0) and
// y is the endpoint g[i].to of the current edge.
#define fec(i, x, y) (int i = head[x], y = g[i].to; i; i = g[i].ne, y = g[i].to)
// printf-style debug output to stderr.
#define dbg(...) fprintf(stderr, __VA_ARGS__)
// Redirect stdin/stdout to the files "<x>.in" / "<x>.out".
#define File(x) freopen(#x".in", "r", stdin), freopen(#x".out", "w", stdout)
// Shorthands for std::pair members and container push_back.
#define fi first
#define se second
#define pb push_back
// Raise a to b when b is larger. Returns 1 if a was changed, 0 otherwise.
template<typename A, typename B> inline char smax(A &a, const B &b) {
    if (!(a < b)) return 0;
    a = b;
    return 1;
}
// Lower a to b when b is smaller. Returns 1 if a was changed, 0 otherwise.
template<typename A, typename B> inline char smin(A &a, const B &b) {
    if (!(b < a)) return 0;
    a = b;
    return 1;
}
// Short type names used throughout the solution.
typedef long long ll;
typedef unsigned long long ull;
typedef std::pair<int, int> pii;
namespace io {
const int SIZE = (1 << 21) + 1;
// Input is buffered in [iS, iT) of ibuf; output accumulates in [obuf, oS) and is
// written out when oS reaches oT. c is an int (not char) so it can hold EOF
// distinctly from every byte value.
char ibuf[SIZE], *iS, *iT, obuf[SIZE], *oS = obuf, *oT = oS + SIZE - 1, qu[55];
int c, f, qr;
// getchar: next input byte as an unsigned char value, or EOF once stdin is exhausted
#define gc() (iS == iT ? (iT = (iS = ibuf) + fread(ibuf, 1, SIZE, stdin), (iS == iT ? EOF : static_cast<unsigned char>(*iS++))) : static_cast<unsigned char>(*iS++))
// Write the buffered output to stdout and reset the buffer.
inline void flush() {
    fwrite(obuf, 1, oS - obuf, stdout);
    oS = obuf;
}
// putchar into the output buffer; flushes once the buffer is full.
inline void putc(char x) {
    *oS++ = x;
    if (oS == oT) flush();
}
// Read a signed integer into x. A '-' anywhere in the run of non-digit
// characters preceding the number makes it negative. At end of input the scan
// stops (x becomes 0) instead of spinning forever on EOF.
template <class I>
inline void gi(I &x) {
    f = 1;
    for (c = gc(); c != EOF && (c < '0' || c > '9'); c = gc())
        if (c == '-') f = -1;
    for (x = 0; c >= '0' && c <= '9'; c = gc()) x = x * 10 + (c & 15);
    x *= f;
}
// Print a signed integer into the output buffer. x is taken by value so the
// caller's variable is left untouched, and digits are extracted without negating
// x so the most negative value of I prints correctly.
template <class I>
inline void print(I x) {
    if (x == 0) {
        putc('0');
        return;
    }
    if (x < 0) putc('-');
    while (x != 0) {
        const I d = x % 10;  // in [-9, 9]: % truncates toward zero
        qu[++qr] = static_cast<char>('0' + (d < 0 ? -d : d));
        x /= 10;
    }
    while (qr) putc(qu[qr--]);
}
}
// Shorthand for the buffered integer reader.
#define read io::gi
// Array bounds: N covers the vertex count n; M covers the operation count m
// (ans is indexed by ansi <= m).
const int N = 100000 + 7;
const int M = 500000 + 7;
// Block index of prefix length x in Mo's ordering (block size blo).
#define bl(x) (((x) - 1) / blo + 1)
// n: vertices, m: operations, dfc: dfs counter, Q: number of prefix queries in q,
// ansi: number of operations with opt != 1 seen so far, blo: Mo's block size.
int n, m, dfc, Q, ansi, blo;
// Current Mo's value, kept equal to the sum over weights c of cl[c] * cr[c].
ll val;
// ans[k]: answer of the k-th operation with opt != 1.
ll ans[M];
// a: vertex weights (compressed to 1..#distinct by lsh); b: scratch for compression;
// cl[c] / cr[c]: how many of the first l / first r vertices in dfs order have weight c.
int a[N], b[N], cl[N], cr[N];
// Heavy-light data for the tree rooted at 1: depth, parent, subtree size,
// heavy child, top of heavy chain, dfs index, and pre[dfn[x]] == x.
int dep[N], f[N], siz[N], son[N], top[N], dfn[N], pre[N];
// One prefix-pair query g(l, r) for Mo's algorithm: opt (+1 / -1) is the sign
// it contributes with, and ans points at the answer slot it accumulates into.
// Because g is symmetric in its two arguments, l and r are stored ordered.
struct Query {
    int opt, l, r;
    ll *ans;
    inline Query() {}
    inline Query(const int &opt, const int &l, const int &r, ll *ans)
        : opt(opt), l(std::min(l, r)), r(std::max(l, r)), ans(ans) {
        assert(this->l <= this->r);
    }
    // Mo's order: by block of l, then by r inside the same block.
    inline bool operator < (const Query &b) const {
        if (bl(l) != bl(b.l)) return l < b.l;
        return r < b.r;
    }
} q[M << 4];
// Adjacency lists: head[x] is the most recently added edge out of x, g[e].ne the
// next edge in the same list, g[e].to its endpoint. Edge index 0 ends a list.
struct Edge {
    int to, ne;
} g[N << 1];
int head[N], tot;
// Prepend the directed edge x -> y to x's list.
inline void addedge(int x, int y) {
    const int e = ++tot;
    g[e].to = y;
    g[e].ne = head[x];
    head[x] = e;
}
// Store an undirected edge as two directed ones.
inline void adde(int x, int y) {
    addedge(x, y);
    addedge(y, x);
}
inline void dfs1(int x, int fa = 0) {
f[x] = fa, dep[x] = dep[fa] + 1, siz[x] = 1;
for fec(i, x, y) if (y != fa) dfs1(y, x), siz[x] += siz[y], siz[y] > siz[son[x]] && (son[x] = y);
}
inline void dfs2(int x, int pa) {
top[x] = pa, dfn[x] = ++dfc, pre[dfc] = x;
if (!son[x]) return; dfs2(son[x], pa);
for fec(i, x, y) if (y != son[x] && y != f[x]) dfs2(y, y);
}
// For x strictly inside p's subtree (rooted at 1), return the child of p whose
// subtree contains x. Climbs heavy chains from x; if the climb lands exactly on
// p, the last chain top jumped from is that (light) child, otherwise x ended on
// p's own chain below p and the answer is p's heavy child.
inline int gson(int x, int p) {
    int last = 0;
    for (; top[x] != top[p]; x = f[last]) last = top[x];
    return x == p ? last : son[p];
}
inline bool intr(int x, int p) { return dfn[x] >= dfn[p] && dfn[x] <= dfn[p] + siz[p] - 1; }
// Queue the count of equal-weight pairs over dfn ranges [l1, r1] x [l2, r2] as
// inclusion-exclusion over prefix queries g(x, y):
//   g(r1, r2) - g(l1-1, r2) - g(r1, l2-1) + g(l1-1, l2-1),
// skipping terms whose prefix would be empty. All terms add into *ans.
inline void addq(int l1, int r1, int l2, int r2, ll *ans) {
    const bool trimLeft = l1 > 1;
    const bool trimRight = l2 > 1;
    q[++Q] = Query(1, r1, r2, ans);
    if (trimLeft) q[++Q] = Query(-1, l1 - 1, r2, ans);
    if (trimRight) q[++Q] = Query(-1, l2 - 1, r1, ans);
    if (trimLeft && trimRight) q[++Q] = Query(1, l1 - 1, l2 - 1, ans);
}
// Coordinate-compress the weights a[1..n] in place to ranks 1..#distinct,
// so they can index the counters cl / cr.
inline void lsh() {
    std::copy(a, a + n + 1, b);
    std::sort(b + 1, b + n + 1);
    int *const last = std::unique(b + 1, b + n + 1);
    for (int i = 1; i <= n; ++i)
        a[i] = static_cast<int>(std::lower_bound(b + 1, last, a[i]) - b);
}
inline void addl(int x) {
val += cr[a[pre[x]]];
++cl[a[pre[x]]];
}
inline void addr(int x) {
val += cl[a[pre[x]]];
++cr[a[pre[x]]];
}
inline void dell(int x) {
val -= cr[a[pre[x]]];
--cl[a[pre[x]]];
}
inline void delr(int x) {
val -= cl[a[pre[x]]];
--cr[a[pre[x]]];
}
// Answer all queued prefix-pair queries with Mo's algorithm, accumulate them into
// ans[], then print ans[1..ansi], one per line.
inline void work() {
    // With Q queries over prefix lengths in [1, n], pointer movement is about
    // Q * blo (left pointer) + n * n / blo (right pointer), minimized at
    // blo ~ n / sqrt(Q). Here Q can be up to 16 * m, far above n, so a fixed
    // sqrt(n) block would make the left pointer dominate. Clamp to >= 1 so bl()
    // never divides by zero; with no queries the value is irrelevant.
    blo = Q ? std::max(1, static_cast<int>(n / std::sqrt(static_cast<double>(Q)))) : 1;
    std::sort(q + 1, q + Q + 1);
    lsh();
    int l = 0, r = 0;
    for (int i = 1; i <= Q; ++i) {
        while (r < q[i].r) addr(++r);
        while (l < q[i].l) addl(++l);
        while (l > q[i].l) dell(l--);
        while (r > q[i].r) delr(r--);
        *q[i].ans += q[i].opt * val;
    }
    for (int i = 1; i <= ansi; ++i) io::print(ans[i]), io::putc('\n');
    io::flush();
}
// Write into lo[] / hi[] the dfn intervals whose union is the subtree of v when the
// tree is rooted at rt, and return how many intervals were written (at most 2).
inline int subtree_ranges(int v, int rt, int lo[2], int hi[2]) {
    if (v == rt) {  // v is the root: its subtree is the whole tree
        lo[0] = 1, hi[0] = n;
        return 1;
    }
    if (!intr(rt, v)) {  // rt outside v's subtree: same subtree as under root 1
        lo[0] = dfn[v], hi[0] = dfn[v] + siz[v] - 1;
        return 1;
    }
    // rt lies strictly below v: v's subtree is everything except the (root-1)
    // subtree of the child c of v that contains rt.
    const int c = gson(rt, v);
    int k = 0;
    if (dfn[c] > 1) lo[k] = 1, hi[k] = dfn[c] - 1, ++k;
    if (dfn[c] + siz[c] - 1 < n) lo[k] = dfn[c] + siz[c], hi[k] = n, ++k;
    return k;
}
// Read the tree and all operations. opt == 1 changes the root; any other op asks
// for the number of equal-weight pairs between the subtrees of x and y under the
// current root, which is queued as prefix queries via addq.
inline void init() {
    read(n), read(m);
    for (int i = 1; i <= n; ++i) read(a[i]);
    for (int i = 1, x, y; i < n; ++i) read(x), read(y), adde(x, y);
    dfs1(1), dfs2(1, 1);
    int rt = 1;
    for (int i = 1; i <= m; ++i) {
        int opt, x, y;
        read(opt);
        if (opt == 1) { read(rt); continue; }
        ++ansi;
        read(x), read(y);
        int xl[2], xr[2], yl[2], yr[2];
        const int cx = subtree_ranges(x, rt, xl, xr);
        const int cy = subtree_ranges(y, rt, yl, yr);
        // Pairs across a union of disjoint intervals = sum over interval pairs.
        for (int s = 0; s < cx; ++s)
            for (int t = 0; t < cy; ++t)
                addq(xl[s], xr[s], yl[t], yr[t], ans + ansi);
    }
}
int main() {
#ifdef hzhkk
    // Local debugging only: take input from a file instead of stdin.
    freopen("hkk.in", "r", stdin);
#endif
    init();  // read the tree and queue every subtree-pair query
    work();  // answer the queued queries with Mo's algorithm and print them
}