SP11414 COT3 - Combat on a tree

题目链接

首先尝试从初态走到 next 状态, 设根节点为 u,某个白色节点为 v, 将 u->v 染黑, 想象把这条路径缩成一点,它的子树们就是一个 next 状态:虽然是多个游戏的和。

一个自然的想法是 dfs, 然后暴力搞, 复杂度是 O(n^2)。

正解是对这个暴力的优化, 直接维护子游戏的集合,支持创建集合、全局异或、合并集合和查询mex, 重点在于实现这些操作的数据结构——0/1Trie。

最后, 这题的方方面面都与 CSP2019day1T3 的对于 O(n^3) 暴力的优化的思想特别相像, 比如暴力,比如正解,比如输出答案;但这个思想本质上也是个简单的东西:集合增添极少元素,某些东西可以快速维护。


若 Trie 要支持全局异或的操作, 就不可避免地要维护每个节点的 “深度”, 可以采取递归写法。

#include<bits/stdc++.h>
using namespace std;

const int N = 100003;

int n, v[N];
int ct, hd[N], nt[N*2+1], vr[N*2+1];
void ad(int u,int v) {
	nt[++ct]=hd[u],hd[u]=ct; vr[ct]=v;
}

int rt[N];
struct Trie{
	int ls[N*100], rs[N*100], cov[N*100], tag[N*100], tot;
	void ins(int &u,int x,int d) {
		u = ++tot;
		if(d==-1) {
			cov[u]=1; return;
		}
		if((x>>d)&1) ins(rs[u],x,d-1);
		else ins(ls[u],x,d-1);
	}
	void put(int u,int x,int d) {
		if(d==-1) return;
		if((x>>d)&1) swap(ls[u],rs[u]);
		tag[u] ^= x;
	}
	void ps_d(int u,int d) {
		if(tag[u]) {
			if(ls[u]) put(ls[u],tag[u],d-1);
			if(rs[u]) put(rs[u],tag[u],d-1);
			tag[u] = 0;
		}
	}
	int meg(int u,int v,int d) {
		if(u&&v) {
			if(d==-1) {
				cov[u]|=cov[v]; return u;
			}
			ps_d(u,d), ps_d(v,d);
			ls[u] = meg(ls[u],ls[v],d-1);
			rs[u] = meg(rs[u],rs[v],d-1);
			cov[u] = (cov[ls[u]] && cov[rs[u]]);
			return u;
		} else return u|v;
	}
	int g_mex(int u,int d) {
		if(d==-1 || !u) return 0;
		if(cov[ls[u]]) return g_mex(rs[u],d-1) ^ (1<<d);
		else return g_mex(ls[u],d-1);
	}
} T;

int sg[N];
void dfs(int x,int fa) {
	int ssg = 0;
	for(int i=hd[x],y=vr[i]; i; i=nt[i],y=vr[i])
		if(y^fa) {
			dfs(y,x);
			ssg ^= sg[y];
		}
	if(!v[x]) T.ins(rt[x],ssg,17);
	for(int i=hd[x],y=vr[i]; i; i=nt[i],y=vr[i])
		if(y^fa) {
			T.put(rt[y],ssg^sg[y],17);
			rt[x] = T.meg(rt[x],rt[y],17);
		}
	sg[x] = T.g_mex(rt[x],17);
}

int ans[N], m;
void g_ans(int x,int fa,int ssg) {
	for(int i=hd[x],y=vr[i]; i; i=nt[i],y=vr[i])
		if(y^fa) ssg ^= sg[y];
	if(!v[x]&&ssg==0) ans[++m] = x;
	for(int i=hd[x],y=vr[i]; i; i=nt[i],y=vr[i])
		if(y^fa) g_ans(y,x,ssg^sg[y]);
}

int main()
{
	scanf("%d",&n);
	for(int i=1;i<=n;++i) scanf("%d",&v[i]);
	for(int i=1,u,v; i<n; ++i) {
		scanf("%d%d",&u,&v);
		ad(u,v), ad(v,u);
	}
	dfs(1,0);
	g_ans(1,0,0);
	if(m==0) puts("-1");
	else
	{
		sort(ans+1,ans+1+m);
		for(int i=1;i<=m;++i) cout << ans[i] << '\n';
	}
	return 0;
}
posted @ 2020-11-13 08:36  xwmwr  阅读(122)  评论(0编辑  收藏  举报