【LG5018】[NOIP2018pj]对称的二叉树
【LG5018】[NOIP2018pj]对称的二叉树
题面
题解
看到这一题全都是用\(O(nlogn)\)的算法过的
考场上写\(O(n)\)算法的我很不开心
然后就发了此篇题解。。。
首先我们可以像树上莫队一样按照 左-右-根 的顺序将这棵树的欧拉序跑下来,
记下开始访问点\(x\)的\(dfs\)序\(L[x]\),和回溯时的\(dfs\)序\(R[x]\)
再将记录欧拉序的数组记为\(P\)
void dfs(int x) {
P[L[x] = ++cnt] = x;
if (t[x].ch[0]) dfs(t[x].ch[0]);
if (t[x].ch[1]) dfs(t[x].ch[1]);
P[R[x] = ++cnt] = x;
t[x].size = t[t[x].ch[0]].size + t[t[x].ch[1]].size + 1;
}
统计出数组\(P\)的两个哈希值,一个是记录点权(\(hs1[0][x]\)),
另一个是记录当前点是左儿子还是右儿子(\(hs2[0][x]\))
for (int i = 1; i <= cnt; i++) hs1[0][i] = hs1[0][i - 1] * base + t[P[i]].v;
for (int i = 1; i <= cnt; i++) hs2[0][i] = hs2[0][i - 1] * base + get(P[i]);
再将这棵树按照 右-左-根 的顺序将这棵树的另一个欧拉序跑下来(记得清空),
记下开始访问点\(x\)的\(dfs\)序\(rL[x]\),和回溯时的\(dfs\)序\(rR[x]\)
void rdfs(int x) {
P[rL[x] = ++cnt] = x;
if (t[x].ch[1]) rdfs(t[x].ch[1]);
if (t[x].ch[0]) rdfs(t[x].ch[0]);
P[rR[x] = ++cnt] = x;
}
再记录一次统计出数组\(P\)的两个哈希值,一个是记录点权(\(hs1[1][x]\)),
另一个是记录当前点是左儿子还是右儿子(\(hs2[1][x]\))(这时候要取异或一下)
for (int i = 1; i <= cnt; i++) hs1[1][i] = hs1[1][i - 1] * base + t[P[i]].v;
for (int i = 1; i <= cnt; i++) hs2[1][i] = hs2[1][i - 1] * base + (get(P[i]) ^ 1);
其中\(get\)函数:
inline int get(int x) { return t[t[x].fa].ch[1] == x; }
然后我们要怎么判断呢?
先判断左右儿子\(ls\)和\(rs\)的\(size\)是否相等
然后再判断第一遍\(dfs\)左儿子所覆盖的欧拉序内和
第二遍\(dfs\)右儿子所覆盖的欧拉序内两个哈希值相不相等即可
if (getHash(hs1[0], L[ls], R[ls]) != getHash(hs1[1], rL[rs], rR[rs])) continue;
if (getHash(hs2[0], L[ls], R[ls]) != getHash(hs2[1], rL[rs], rR[rs])) continue;
然而常数过大,速度被nlogn吊打
完整代码
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <algorithm>
using namespace std;
namespace IO {
const int BUFSIZE = 1 << 20;
char ibuf[BUFSIZE], *is = ibuf, *it = ibuf;
inline char gc() {
if (is == it) it = (is = ibuf) + fread(ibuf, 1, BUFSIZE, stdin);
return *is++;
}
}
inline int gi() {
register int data = 0, w = 1;
register char ch = 0;
while (ch != '-' && (ch > '9' || ch < '0')) ch = IO::gc();
if (ch == '-') w = -1 , ch = IO::gc();
while (ch >= '0' && ch <= '9') data = data * 10 + (ch ^ 48), ch = IO::gc();
return w * data;
}
#define MAX_N 1000005
struct Node { int ch[2], fa, size, v; } t[MAX_N];
inline int get(int x) { return t[t[x].fa].ch[1] == x; }
typedef unsigned long long ull;
const ull base = 100007;
ull pw[MAX_N << 1];
ull hs1[2][MAX_N << 1], hs2[2][MAX_N << 1];
ull getHash(ull *hs, int l, int r) { return hs[r] - hs[l - 1] * pw[r - l + 1]; }
int N, L[MAX_N], R[MAX_N], rL[MAX_N], rR[MAX_N], P[MAX_N << 1], cnt;
void dfs(int x) {
P[L[x] = ++cnt] = x;
if (t[x].ch[0]) dfs(t[x].ch[0]);
if (t[x].ch[1]) dfs(t[x].ch[1]);
P[R[x] = ++cnt] = x;
t[x].size = t[t[x].ch[0]].size + t[t[x].ch[1]].size + 1;
}
void rdfs(int x) {
P[rL[x] = ++cnt] = x;
if (t[x].ch[1]) rdfs(t[x].ch[1]);
if (t[x].ch[0]) rdfs(t[x].ch[0]);
P[rR[x] = ++cnt] = x;
}
int main () {
N = gi(); pw[0] = 1;
for (int i = 1; i <= 2 * N; i++) pw[i] = pw[i - 1] * base;
for (int i = 1; i <= N; i++) t[i].v = gi();
for (int x = 1; x <= N; x++) {
int ls = gi(), rs = gi();
if (ls != -1) t[x].ch[0] = ls, t[ls].fa = x;
if (rs != -1) t[x].ch[1] = rs, t[rs].fa = x;
}
dfs(1);
for (int i = 1; i <= cnt; i++) hs1[0][i] = hs1[0][i - 1] * base + t[P[i]].v;
for (int i = 1; i <= cnt; i++) hs2[0][i] = hs2[0][i - 1] * base + get(P[i]);
cnt = 0; rdfs(1);
for (int i = 1; i <= cnt; i++) hs1[1][i] = hs1[1][i - 1] * base + t[P[i]].v;
for (int i = 1; i <= cnt; i++) hs2[1][i] = hs2[1][i - 1] * base + (get(P[i]) ^ 1);
int ans = 1;
for (int x = 1; x <= N; x++) {
int ls = t[x].ch[0], rs = t[x].ch[1];
if (t[ls].size != t[rs].size) continue;
if (getHash(hs1[0], L[ls], R[ls]) != getHash(hs1[1], rL[rs], rR[rs])) continue;
if (getHash(hs2[0], L[ls], R[ls]) != getHash(hs2[1], rL[rs], rR[rs])) continue;
ans = max(ans, t[x].size);
}
printf("%d\n", ans);
return 0;
}