题解 牌佬

传送门

先读错一下题:是不是 \(x, z\) 分别在 \(y\) 的两个不同子树里啊?那不是启发式合并一下就完了?
然后发现 \(y\) 是在路径上
然后发现 \(y\) 不是 lca 的话,\(x\) 和 \(z\) 就一个在 \(y\) 的子树内、一个在子树外
启发式合并 + 哈希表处理掉是 lca 的情况
然后对每个点 \(i\),需要知道其子树内的点的编号集合是不是关于 \(i\) 对称(即在 \(1\le 2i-j\le n\) 的范围内,\(j\) 在子树内当且仅当 \(2i-j\) 也在子树内);若不对称,就存在一个在子树内、一个在子树外的一对点
线段树维护正反 hash + 线段树合并即可
复杂度 \(O(n\log n)\),又 tm 卡常
不过可以卡时输出 NO

点击查看代码
#include <bits/stdc++.h>
#include <bits/extc++.h>
using namespace std;
using namespace __gnu_pbds;
#define INF 0x3f3f3f3f
#define N 1050000
#define pb push_back
#define ll long long
#define ull unsigned long long
//#define int long long

// Buffered stdin reader: buf holds the current chunk, [p1, p2) is the unread part.
char buf[1<<21], *p1=buf, *p2=buf;
// Replacement for getchar(): refills buf with fread when it is exhausted and yields
// EOF when nothing more can be read. The byte is widened through unsigned char so a
// 0xFF byte in the input can never be mistaken for EOF.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:(unsigned char)*p1++)
/**
 * Reads the next decimal integer from the buffered input.
 *
 * Every non-digit character before the number is skipped; each '-' seen while
 * skipping flips the sign. The character right after the last digit is consumed.
 *
 * @return the parsed value, or 0 when end of input is reached before any digit
 *         (the original loop never terminated in that situation).
 */
inline int read() {
	int ans=0, f=1;
	// Keep the character in an int: EOF stays distinguishable and no negative
	// value is ever passed to a <cctype> classifier.
	int c=getchar();
	while (c!=EOF && (c<'0' || c>'9')) {
		if (c=='-') f=-f;
		c=getchar();
	}
	while (c>='0' && c<='9') {
		ans=ans*10+(c-'0');
		c=getchar();
	}
	return ans*f;
}

// Number of vertices; read in main.
int n;
// pw[i] = base^i with unsigned 64-bit wrap-around; filled in main for i in [0, n].
ull pw[N];
// sub[u]: vertex labels gathered by dfs1 for the subtree of u.
vector<int> sub[N];
// Base of the polynomial hash stored in the segment tree.
const ull base=13131;
// Used as a set of vertex labels (only key presence is ever tested).
cc_hash_table<int, bool> mp;
// Adjacency list as linked edge chains: e[i].next is the next edge of the same
// source vertex, -1 terminates a chain (head[] is filled with -1 in main).
struct edge{int to, next;}e[N<<1];
// head[u]: first edge index of u; siz[u]: subtree size; msiz[u]: size of the largest
// child subtree; mson[u]: the child attaining msiz[u] (stays 0 for a leaf);
// ecnt: number of edge slots used so far.
int head[N], siz[N], msiz[N], mson[N], ecnt;
// Appends the directed edge s -> t at the front of s's edge chain.
inline void add(int s, int t) {
	const int id=++ecnt;
	e[id].to=t;
	e[id].next=head[s];
	head[s]=id;
}
// Child accessors of segment-tree node p.
#define ls(p) lson[p]
#define rs(p) rson[p]
// Recomputes both hashes of p from its children. NOTE: expands to code that uses the
// locals `tl`, `tr` and `mid` of the enclosing function, so callers must keep those names.
//   dat1: left child shifted by the right half's length (tr-mid), plus right child.
//   dat2: left child, plus right child shifted by the left half's length (mid-tl+1).
#define pushup(p) dat1[p]=dat1[ls(p)]*pw[tr-mid]+dat1[rs(p)], dat2[p]=dat2[ls(p)]+dat2[rs(p)]*pw[mid-tl+1]
// Per-node hashes of the covered index range, read in the two opposite directions.
ull dat1[N*50], dat2[N*50];
// Child links (0 = no child; node 0 is never written, so its hashes stay 0),
// rot[u] = root of the tree built for vertex u, tot = number of nodes allocated.
int lson[N*50], rson[N*50], rot[N], tot;
/**
 * Point assignment in a dynamically allocated segment tree.
 *
 * @param p        node covering [tl, tr]; allocated on demand (hence the reference)
 * @param tl, tr   index range covered by p
 * @param pos      index to assign
 * @param val      value stored at the leaf (both hashes of a leaf equal val)
 */
void upd(int& p, int tl, int tr, int pos, int val) {
	if (p==0) p=++tot;
	if (tl==tr) {
		dat2[p]=val;
		dat1[p]=dat2[p];
		return;
	}
	// `mid`, `tl`, `tr` must keep these names: pushup() refers to them.
	const int mid=(tl+tr)>>1;
	if (pos>mid) upd(rs(p), mid+1, tr, pos, val);
	else upd(ls(p), tl, mid, pos, val);
	pushup(p);
}
/**
 * Merges tree p2 into tree p1 (both covering [tl, tr]) and returns the resulting root.
 *
 * When either side is empty the other one is returned untouched. Two non-empty
 * leaves are not expected to meet; the assert guards that assumption.
 */
int merge(int p1, int p2, int tl, int tr) {
	if (p1==0 || p2==0) return p1|p2;
	assert(tl!=tr);
	// `mid`, `tl`, `tr` must keep these names: pushup() refers to them.
	const int mid=(tl+tr)>>1;
	const int leftRoot=merge(ls(p1), ls(p2), tl, mid);
	ls(p1)=leftRoot;
	const int rightRoot=merge(rs(p1), rs(p2), mid+1, tr);
	rs(p1)=rightRoot;
	pushup(p1);
	return p1;
}
/**
 * Hash of the index range [ql, qr] inside node p (covering [tl, tr]), combined the
 * same way pushup builds dat1: the left part is shifted by the length of the right part.
 * An absent node contributes 0.
 */
ull query1(int p, int tl, int tr, int ql, int qr) {
	if (p==0) return 0;
	if (tl>=ql && tr<=qr) return dat1[p];
	const int mid=(tl+tr)>>1;
	// Query lies entirely in one half.
	if (ql>mid) return query1(rs(p), mid+1, tr, ql, qr);
	if (qr<=mid) return query1(ls(p), tl, mid, ql, qr);
	// Query straddles mid: the right part covers (mid, min(qr, tr)].
	const ull leftHash=query1(ls(p), tl, mid, ql, qr);
	const ull rightHash=query1(rs(p), mid+1, tr, ql, qr);
	return leftHash*pw[min(qr, tr)-mid]+rightHash;
}
/**
 * Hash of the index range [ql, qr] inside node p (covering [tl, tr]), combined the
 * same way pushup builds dat2: the right part is shifted by the length of the left part.
 * An absent node contributes 0.
 */
ull query2(int p, int tl, int tr, int ql, int qr) {
	if (p==0) return 0;
	if (tl>=ql && tr<=qr) return dat2[p];
	const int mid=(tl+tr)>>1;
	// Query lies entirely in one half.
	if (ql>mid) return query2(rs(p), mid+1, tr, ql, qr);
	if (qr<=mid) return query2(ls(p), tl, mid, ql, qr);
	// Query straddles mid: the left part covers [max(ql, tl), mid].
	const ull leftHash=query2(ls(p), tl, mid, ql, qr);
	const ull rightHash=query2(rs(p), mid+1, tr, ql, qr);
	return leftHash+rightHash*pw[mid-max(ql, tl)+1];
}

// Computes siz[u] for every vertex below u and records, per vertex, the child with
// the largest subtree in mson[] (its size in msiz[]). On ties the first child wins.
void dfs(int u, int fa) {
	siz[u]=1;
	for (int i=head[u]; i!=-1; i=e[i].next) {
		const int v=e[i].to;
		if (v==fa) continue;
		dfs(v, u);
		if (msiz[u]<siz[v]) {
			msiz[u]=siz[v];
			mson[u]=v;
		}
		siz[u]+=siz[v];
	}
}

// Counts how many labels dfs1 copies from light subtrees (diagnostic only).
int cnt=0;

/**
 * Searches for a pair of labels x, z lying in two different child subtrees of some
 * vertex u with x + z == 2*u. On success prints "YES x u z" and terminates the program.
 *
 * Light children are processed first, the heavy child (mson[u]) last, so that mp
 * still holds the heavy subtree when the light subtrees are checked against it.
 * On return, sub[u] lists the labels of u's subtree and mp has had them inserted.
 */
void dfs1(int u, int fa) {
	const int heavy=mson[u];
	if (heavy==0) {
		// No heavy child means a leaf: restart the label set with just this vertex.
		mp.clear();
		sub[u].pb(u);
		mp[u]=1;
		return ;
	}
	for (int i=head[u]; i!=-1; i=e[i].next) {
		const int v=e[i].to;
		if (v!=fa && v!=heavy) dfs1(v, u);
	}
	dfs1(heavy, u);
	// Take over the heavy child's label list without copying it.
	swap(sub[u], sub[heavy]);
	for (int i=head[u]; i!=-1; i=e[i].next) {
		const int v=e[i].to;
		if (v==fa || v==heavy) continue;
		// First test every label of this light subtree against what is already in mp ...
		for (int label:sub[v]) {
			const int mirror=2*u-label;
			sub[u].pb(label);
			++cnt;
			if (mirror<1 || mirror>n) continue;
			if (mp.find(mirror)!=mp.end()) {
				printf("YES %d %d %d\n", label, u, mirror);
				exit(0);
			}
		}
		// ... and only afterwards insert them, so both labels never come from the same child.
		for (int label:sub[v]) mp[label]=1;
	}
	sub[u].pb(u);
	mp[u]=1;
}

// Appends u and then every vertex of its subtree (parent fa excluded) to sta, in preorder.
void dfs3(int u, int fa, vector<int>& sta) {
	sta.push_back(u);
	for (int i=head[u]; i!=-1; i=e[i].next) {
		const int child=e[i].to;
		if (child!=fa) dfs3(child, u, sta);
	}
}

/**
 * Searches for a label x inside u's subtree whose mirror 2*u - x is a valid label
 * (1..n) that is NOT inside u's subtree. On success prints "YES x u (2*u-x)" and
 * terminates the program.
 *
 * rot[u] ends up as a segment tree over labels 1..n with a 1 at every label of u's
 * subtree (children's trees are merged in). The subtree is symmetric around u iff
 * the hash of [u-len, u] equals the opposite-direction hash of [u, u+len], where
 * len = min(u-1, n-u) keeps both windows inside 1..n.
 *
 * @param u   current vertex
 * @param fa  parent of u (0 for the root)
 */
void dfs2(int u, int fa) {
	upd(rot[u], 1, n, u, 1);
	for (int i=head[u]; ~i; i=e[i].next) {
		const int v=e[i].to;
		if (v==fa) continue;
		dfs2(v, u);
		// Store the merged root explicitly instead of relying on merge() happening
		// to return its first argument when that argument is non-empty.
		rot[u]=merge(rot[u], rot[v], 1, n);
	}
	const int len=min(u-1, n-u);
	// Deliberate time cut-off: after roughly two seconds of CPU time give up and
	// answer NO (this may be wrong). Expressed via CLOCKS_PER_SEC so the limit is
	// two seconds on every platform, not only where one tick is a microsecond.
	// rand() is tested first so clock() is called only occasionally.
	if (!(rand()%300) && clock()>2*CLOCKS_PER_SEC) {puts("NO"); exit(0);}
	if (query1(rot[u], 1, n, u-len, u)==query2(rot[u], 1, n, u, u+len)) return;

	// Hashes differ, so the subtree is not symmetric around u: list it and find a witness.
	mp.clear();
	vector<int> sta;
	dfs3(u, fa, sta);
	for (int label:sta) mp[label]=1;
	for (int label:sta) {
		const int tem=2*u-label;
		if (tem>=1&&tem<=n&&mp.find(tem)==mp.end()) {printf("YES %d %d %d\n", label, u, tem); exit(0);}
	}
	// Differing hashes imply differing windows, so a witness must have been found.
	assert(0);
}

// Entry point: redirects stdio to the task files, reads the tree, precomputes the
// hash powers and runs the three passes. Either pass may print YES and exit early;
// otherwise NO is printed.
signed main()
{
	freopen("gangster.in", "r", stdin);
	freopen("gangster.out", "w", stdout);

	n=read();

	// Powers of the hash base, pw[0..n].
	pw[0]=1;
	for (int k=0; k<n; ++k) pw[k+1]=pw[k]*base;

	// -1 marks the end of every adjacency chain.
	memset(head, -1, sizeof(head));
	for (int remaining=n-1; remaining>0; --remaining) {
		const int u=read();
		const int v=read();
		add(u, v);
		add(v, u);
	}

	dfs(1, 0);
	dfs1(1, 0);
	dfs2(1, 0);
	puts("NO");

	return 0;
}
posted @ 2022-07-20 15:03  Administrator-09  阅读(3)  评论(0)  编辑  收藏  举报