luoguP3292 [SCOI2016]幸运数字(点分治做法)
题意
考虑点分治,每次处理过重心的询问(即两点在重心的不同子树中)。
求出每个点到重心的线性基,之后对过重心的询问合并两点线性基求解。
code:
#include<bits/stdc++.h>
using namespace std;
#define pii pair<int,int>
#define mkp make_pair
#define fir first
#define sec second
typedef long long ll;
const int maxn=20010;
const int maxm=200010;
const int inf=1e9;
int n,m,cnt,root,maxsize=inf,trsize;
int head[maxn],size[maxn],check[maxn];
ll a[maxn],ans[maxm];
bool vis[maxn];
vector<pii>ve[maxn];
struct edge{int to,nxt;}e[maxn<<1];
struct Xord
{
ll d[65];
Xord(){memset(d,0,sizeof(d));}
inline void clear(){memset(d,0,sizeof(d));}
inline void insert(ll x)
{
for(int i=61;~i;i--)
{
if(!(x&(1ll<<i)))continue;
if(!d[i]){d[i]=x;return;}
else x^=d[i];
}
}
inline ll query()
{
ll res=0;
for(int i=61;~i;i--)res=max(res,res^d[i]);
return res;
}
}xord[maxn];
inline void add(int u,int v)
{
e[++cnt].nxt=head[u];
head[u]=cnt;
e[cnt].to=v;
}
void getroot(int x,int fa)
{
size[x]=1;
int maxx=0;
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to;
if(y==fa||vis[y])continue;
getroot(y,x);size[x]+=size[y];
maxx=max(maxx,size[y]);
}
maxx=max(maxx,trsize-size[x]);
if(maxx<maxsize)maxsize=maxx,root=x;
}
void getxor(int x,int fa)
{
xord[x]=xord[fa];xord[x].insert(a[x]);
for(unsigned int i=0;i<ve[x].size();i++)
{
int y=ve[x][i].fir;
if(check[y]!=root)continue;
Xord tmp=xord[x];
for(int j=0;j<=60;j++)if(xord[y].d[j])tmp.insert(xord[y].d[j]);
ans[ve[x][i].sec]=max(ans[ve[x][i].sec],tmp.query());
}
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to;
if(y==fa||vis[y])continue;
getxor(y,x);
}
}
void mark(int x,int fa,int k)
{
check[x]=k;
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to;
if(y==fa||vis[y])continue;
mark(y,x,k);
}
}
void solve(int x)
{
//cerr<<x<<endl;
vis[x]=1;check[x]=x;
xord[x].clear();xord[x].insert(a[x]);
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to;
if(vis[y])continue;
getxor(y,x);mark(y,x,x);
}
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to;
if(vis[y])continue;
maxsize=inf;trsize=size[y];getroot(y,0);
solve(root);
}
}
int main()
{
//freopen("test.in","r",stdin);
//freopen("test.out","w",stdout);
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)scanf("%lld",&a[i]);
for(int i=1;i<n;i++)
{
int u,v;scanf("%d%d",&u,&v);
add(u,v),add(v,u);
}
for(int i=1;i<=m;i++)
{
int x,y;scanf("%d%d",&x,&y);
if(x==y)ans[i]=a[x];
else ve[x].push_back(mkp(y,i)),ve[y].push_back(mkp(x,i));
}
trsize=n;maxsize=inf;getroot(1,0);
solve(root);
for(int i=1;i<=m;i++)printf("%lld\n",ans[i]);
return 0;
}