Luogu P3292 [SCOI2016]幸运数字

看到异或最值,显然想到线性基。
用树上倍增的方法,维护当前点\(x\)到倍增父节点\(fa[x][i]\)这条路径上的线性基,在倍增的时候暴力合并即可。
注意这个线性基的倍增数组是没有包括最后一个点的信息的,需要特殊处理。然后就搞完了。
时间复杂度\(O(n*log_n*log_v+q*log_n*log_v)\)

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

#define R register
#define LL long long

const int MAXN=2e4+10;
const int MAXQ=2e5+10;

int n,q;
LL val[MAXN];
int lg[MAXN];

int head[MAXN],cnt;
struct edge { int to,next; } e[MAXN<<1];
inline void add(int x,int y) { e[++cnt]={y,head[x]}; head[x]=cnt; }

class Basic {
	#define MB 63
	public:
		LL p[MB+1];
		Basic() { memset(p,0,sizeof(p)); }
		inline void clear() { memset(p,0,sizeof(p)); }
		inline void ins(LL x) {
			for(R int i=MB;i>=0;i--)
				if(x&(1LL<<i)) {
					if(!p[i]) { p[i]=x; return ;}
					else x^=p[i];
				}
		}
		inline LL ask() {
			LL ans=0;
			for(R int i=MB;i>=0;i--) 
				if((ans^p[i])>ans) ans^=p[i];
			return ans;
		}
};

inline Basic operator + (Basic x,Basic &y) {
	for(R int i=MB;i>=0;i--)
		if(y.p[i]) x.ins(y.p[i]);
	return x;
}

int dep[MAXN],fa[MAXN][15];
Basic bas[MAXN][15];

inline void dfs(int x,int fx) {
	dep[x]=dep[fx]+1;
	fa[x][0]=fx; bas[x][0].ins(val[x]);
	for(R int i=1;i<=lg[dep[x]];i++) {
		fa[x][i]=fa[fa[x][i-1]][i-1];
		bas[x][i]=bas[x][i]+bas[x][i-1];
		bas[x][i]=bas[x][i]+bas[fa[x][i-1]][i-1];
	}
	for(R int i=head[x];i;i=e[i].next) {
		int y=e[i].to;
		if(y==fx) continue;
		dfs(y,x);
	}
}

Basic Ans;

inline LL ask(int x,int y) {
	if(dep[x]<dep[y]) swap(x,y);
	Ans.clear();
	while(dep[x]>dep[y]) {
		Ans=Ans+bas[x][lg[dep[x]-dep[y]]];
		x=fa[x][lg[dep[x]-dep[y]]];
	}
	if(x==y) {
		Ans.ins(val[x]);
		return Ans.ask();
	}
	for(R int i=lg[dep[x]];i>=0;i--)
		if(fa[x][i]!=fa[y][i]) {
			Ans=Ans+bas[x][i];
			Ans=Ans+bas[y][i];
			x=fa[x][i];
			y=fa[y][i];
		}
	Ans.ins(val[x]);
	Ans.ins(val[y]);
	Ans.ins(val[fa[x][0]]);
	return Ans.ask();
}

inline void Init() {
	scanf("%d%d",&n,&q);
	for(R int i=1;i<=n;i++) scanf("%lld",&val[i]);
	for(R int i=1;i<n;i++ ) {
		int x,y; scanf("%d%d",&x,&y);
		add(x,y); add(y,x);
	}
	for(R int i=2;i<=n;i++) lg[i]=lg[i>>1]+1;
	dfs(1,0);
}

inline void Solve() {
	while(q--) {
		int x,y;
		scanf("%d%d",&x,&y);
		printf("%lld\n",ask(x,y));
	}
}

int main() {
	freopen("a.in","r",stdin);
	freopen("a.out","w",stdout);
	Init();
	Solve();
	return 0;
}
posted @ 2020-05-03 09:23  HN-wrp  阅读(253)  评论(0编辑  收藏  举报