P7206 [COCI2019-2020#3] Lampice
题意
一棵树,每个结点有一个小写字母,求最长的回文路径。
思路
路径问题我们可以想到用点分治来求解。
我们点分治的时候,维护每个点到重心的正反路径的 hash 值。
设点到重心的 hash 值为 \(A\),重心到点的 hash 值为 \(B\)。
假设实现我们有两个长度分别为 \(L_1,L_2\) 的路径,它们的 LCA 是当前分治到的重心。
这两条路径要组成回文路径当且仅当:
\[A_1+B_2\times 27^{L_1}=A_2+B_1\times 27^{L_2}
\]
移项,同除 \(27^{L_1+L_2}\):
\[A_1/27^{L_1+L_2}-B_1/ 27^{L_1}=A_2/27^{L_1+L_2}-B_2/ 27^{L_2}
\]
这里 \(L_1+L_2\) 就是回文路径的长度。
我们知道,一个路径如果是回文路径,那么去掉头尾两个结点它仍然是回文路径。
我们考虑奇偶长度分开来二分,然后每次做点分治,并记录下 \(A/27^{len}-B/ 27^{L}\)。
时间复杂度为 \(O(n\log^3n)\),需要卡常。
一个卡常技巧就是,如果 分治到的联通块大小 小于 固定的回文路径长度 时,就直接返回。
记录时将 map 换成 unordered_map 实测快不少。
代码
#include<iostream>
#include<fstream>
#include<algorithm>
#include<cmath>
#include<cstdlib>
#include<cstring>
#include<queue>
#include<unordered_map>
#include<set>
#include<bitset>
#define LL long long
using namespace std;
namespace MOD
{
const int mod = 1e9 + 7;
inline int add(int a, int b) {return a + b >= mod ? a + b - mod : a + b;}
inline int mul(int a, int b) {return 1ll * a * b % mod;}
inline int sub(int a, int b) {return a - b < 0 ? a - b + mod : a - b;}
inline int fast_pow(int a, int b)
{
int re = 1;
while(b)
{
if(b & 1) re = mul(re, a);
a = mul(a, a);
b >>= 1;
}
return re;
}
inline int inv(int a) {return fast_pow(a, mod - 2);}
} using namespace MOD;
int n, ans = 1, p[50005], invp[50005]; char c[50005];
struct Node
{
int to, nxt;
}r[100005]; int he[50005];
inline void Edge_add(int u, int v)
{
static int cnt = 0;
r[++cnt] = (Node){v, he[u]};
he[u] = cnt;
}
bitset<50005> vis;
int tot, sz[50005], Mx[50005], rt;
void getrt(int now, int fa)
{
sz[now] = 1, Mx[now] = 0;
for(int i = he[now]; i; i = r[i].nxt)
{
int to = r[i].to;
if(vis[to] || to == fa) continue;
getrt(to, now);
sz[now] += sz[to];
Mx[now] = max(Mx[now], sz[to]);
}
Mx[now] = max(Mx[now], tot - sz[now]);
if(Mx[now] < Mx[rt]) rt = now;
}
int h1[50005], h2[50005], dep[50005], q[50005], ta, nlen;
int L, R; bool chk;
#define pii pair<int, int>
#define mp make_pair
struct hash_pair {
template <class T1, class T2>
size_t operator()(const pair<T1, T2>& p) const
{
auto hash1 = hash<T1>{}(p.first);
auto hash2 = hash<T2>{}(p.second);
return hash1 ^ hash2;
}
};
unordered_map<pii, bool, hash_pair> f;
inline void gethash(int now, int fa)
{
for(int i = he[now]; i; i = r[i].nxt)
{
int to = r[i].to;
if(vis[to] || to == fa) continue;
h1[to] = add(mul(h1[now], 27), c[to] - 'a' + 1), h2[to] = add(h2[now], mul(c[to]- 'a' + 1, p[dep[to] = dep[now] + 1]));
q[++ta] = to;
gethash(to, now);
}
}
#define h(H1, H2, sum, len) sub(mul(H1, invp[sum]), mul(H2, invp[len]))
inline void calc(int now)
{
f.clear();
h1[now] = h2[now] = c[now] - 'a' + 1, dep[now] = 0;
f[mp(0, 0)] = 1;
for(int i = he[now]; i; i = r[i].nxt)
{
int to = r[i].to;
if(vis[to]) continue;
h1[to] = add(c[to] - 'a' + 1, h1[now] * 27); h2[to] = add(h2[now], (c[to] - 'a' + 1) * 27), dep[to] = 1;
q[ta = 1] = to;
gethash(to, now);
for(int j = 1; j <= ta; j++)
{
int val = h(h1[q[j]], h2[q[j]], nlen, dep[q[j]] + 1);
if(f[mp(val, nlen - dep[q[j]] - 1)])
{chk = true; return;}
}
for(int j = 1; j <= ta; j++)
{
int H1 = sub(h1[q[j]], mul(h1[now], p[dep[q[j]]]));
int H2 = mul(sub(h2[q[j]], h1[now]), invp[1]);
int val = h(H1, H2, nlen, dep[q[j]]);
f[mp(val, dep[q[j]])] = 1;
}
}
}
void solve(int now)
{
vis[now] = 1, calc(now); if(chk) return;
int rp = tot - sz[now];
for(int i = he[now]; i; i = r[i].nxt)
{
int to = r[i].to;
if(vis[to]) continue;
tot = Mx[rt = 0] = (sz[to] > sz[now] ? rp : sz[to]);
if(tot < nlen) continue;
getrt(to, 0);
solve(rt);
}
}
signed main()
{
#ifndef ONLINE_JUDGE
freopen("test.in", "r", stdin);
freopen("test.out", "w", stdout);
#endif
scanf("%d%s", &n, c + 1);
p[0] = invp[0] = 1; for(int i = 1; i <= n; i++) p[i] = mul(p[i - 1], 27);
invp[n] = inv(p[n]); for(int i = n - 1; i >= 1; i--) invp[i] = mul(invp[i + 1], 27);
for(int i = 1; i < n; i++)
{
int u, v;
scanf("%d%d", &u, &v);
Edge_add(u, v), Edge_add(v, u);
}
L = 1, R = (n + 3) >> 1;
while(L + 1 < R)
{
int mid = (L + R) >> 1;
vis = chk = 0; nlen = (mid << 1) - 1;
tot = Mx[rt = 0] = n;
getrt(1, 0);
solve(rt);
if(chk) L = mid;
else R = mid;
}
ans = max(ans, (L << 1) - 1);
L = ans >> 1, R = (n + 2) >> 1;
while(L + 1 < R)
{
int mid = (L + R) >> 1;
vis = chk = 0; nlen = mid << 1;
tot = Mx[rt = 0] = n;
getrt(1, 0);
solve(rt);
if(chk) L = mid;
else R = mid;
}
ans = max(ans, L << 1);
printf("%d", ans);
return 0;
}