题解 牌佬
先读错一下题:是不是 \(x, z\) 分别在 \(y\) 的两个不同子树里啊?那不是启发式合并一下就完了?
然后发现 \(y\) 是在路径上
然后发现 \(y\) 不是 lca 的话 \(x\) 和 \(z\) 就一个在子树内一个在子树外
启发式合并 + 哈希表处理掉是 lca 的情况
然后 \(\forall i\) 需要知道子树内的点的下标是不是关于 \(i\) 对称
线段树维护正反 hash + 线段树合并即可
复杂度 \(O(n\log n)\),又 tm 卡常
不过可以卡时输出 NO
点击查看代码
#include <bits/stdc++.h>
#include <bits/extc++.h>
using namespace std;
using namespace __gnu_pbds;
#define INF 0x3f3f3f3f
#define N 1050000
#define pb push_back
#define ll long long
#define ull unsigned long long
//#define int long long
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
int ans=0, f=1; char c=getchar();
while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
return ans*f;
}
int n;
ull pw[N];
vector<int> sub[N];
const ull base=13131;
cc_hash_table<int, bool> mp;
struct edge{int to, next;}e[N<<1];
int head[N], siz[N], msiz[N], mson[N], ecnt;
inline void add(int s, int t) {e[++ecnt]={t, head[s]}; head[s]=ecnt;}
#define ls(p) lson[p]
#define rs(p) rson[p]
#define pushup(p) dat1[p]=dat1[ls(p)]*pw[tr-mid]+dat1[rs(p)], dat2[p]=dat2[ls(p)]+dat2[rs(p)]*pw[mid-tl+1]
ull dat1[N*50], dat2[N*50];
int lson[N*50], rson[N*50], rot[N], tot;
void upd(int& p, int tl, int tr, int pos, int val) {
if (!p) p=++tot;
if (tl==tr) {dat1[p]=dat2[p]=val; return ;}
int mid=(tl+tr)>>1;
if (pos<=mid) upd(ls(p), tl, mid, pos, val);
else upd(rs(p), mid+1, tr, pos, val);
pushup(p);
}
int merge(int p1, int p2, int tl, int tr) {
if (!(p1&&p2)) return p1|p2;
assert(tl!=tr);
int mid=(tl+tr)>>1;
ls(p1)=merge(ls(p1), ls(p2), tl, mid);
rs(p1)=merge(rs(p1), rs(p2), mid+1, tr);
pushup(p1);
return p1;
}
ull query1(int p, int tl, int tr, int ql, int qr) {
if (!p) return 0;
if (ql<=tl&&qr>=tr) return dat1[p];
int mid=(tl+tr)>>1;
if (ql<=mid&&qr>mid) return query1(ls(p), tl, mid, ql, qr)*pw[min(qr, tr)-mid]+query1(rs(p), mid+1, tr, ql, qr);
else if (ql<=mid) return query1(ls(p), tl, mid, ql, qr);
else return query1(rs(p), mid+1, tr, ql, qr);
}
ull query2(int p, int tl, int tr, int ql, int qr) {
if (!p) return 0;
if (ql<=tl&&qr>=tr) return dat2[p];
int mid=(tl+tr)>>1;
if (ql<=mid&&qr>mid) return query2(ls(p), tl, mid, ql, qr)+query2(rs(p), mid+1, tr, ql, qr)*pw[mid-max(ql, tl)+1];
else if (ql<=mid) return query2(ls(p), tl, mid, ql, qr);
else return query2(rs(p), mid+1, tr, ql, qr);
}
void dfs(int u, int fa) {
siz[u]=1;
for (int i=head[u],v; ~i; i=e[i].next) if ((v=e[i].to)!=fa) {
dfs(v, u);
if (siz[v]>msiz[u]) msiz[u]=siz[v], mson[u]=v;
siz[u]+=siz[v];
}
}
int cnt=0;
void dfs1(int u, int fa) {
// cout<<"dfs1: "<<u<<' '<<cnt<<endl;
if (!mson[u]) {mp.clear(); sub[u].pb(u); mp[u]=1; return ;}
for (int i=head[u],v; ~i; i=e[i].next)
if ((v=e[i].to)!=fa&&v!=mson[u])
dfs1(v, u);
dfs1(mson[u], u);
swap(sub[u], sub[mson[u]]);
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v==fa||v==mson[u]) continue;
for (auto& it:sub[v]) {
int tem=2*u-it;
sub[u].pb(it);
++cnt;
if (tem>=1&&tem<=n&&mp.find(tem)!=mp.end()) {printf("YES %d %d %d\n", it, u, tem); exit(0);}
}
for (auto& it:sub[v]) mp[it]=1;
}
sub[u].pb(u);
mp[u]=1;
}
void dfs3(int u, int fa, vector<int>& sta) {
sta.pb(u);
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v==fa) continue;
dfs3(v, u, sta);
}
}
void dfs2(int u, int fa) {
upd(rot[u], 1, n, u, 1);
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v==fa) continue;
dfs2(v, u);
merge(rot[u], rot[v], 1, n);
}
int len=min(u-1, n-u);
if (!(rand()%300) && clock()>2000000) {puts("NO"); exit(0);}
if (query1(rot[u], 1, n, u-len, u)!=query2(rot[u], 1, n, u, u+len)) {
mp.clear();
vector<int> sta;
dfs3(u, fa, sta);
for (auto& it:sta) mp[it]=1;
for (auto& it:sta) {
int tem=2*u-it;
if (tem>=1&&tem<=n&&mp.find(tem)==mp.end()) {printf("YES %d %d %d\n", it, u, tem); exit(0);}
}
// cerr<<"u: "<<u<<endl;
assert(0);
}
}
signed main()
{
freopen("gangster.in", "r", stdin);
freopen("gangster.out", "w", stdout);
n=read();
memset(head, -1, sizeof(head));
for (int i=1,u,v; i<n; ++i) {
u=read(); v=read();
add(u, v); add(v, u);
}
pw[0]=1;
for (int i=1; i<=n; ++i) pw[i]=pw[i-1]*base;
dfs(1, 0), dfs1(1, 0), dfs2(1, 0);
puts("NO");
return 0;
}