CF766 E. Mahmoud and a xor trip
题目传送门:https://codeforces.com/problemset/problem/766/E
题目大意:
给定一棵节点数为\(n\)的树,第\(i\)个点的点权为\(V_i\)。求所有点对\((x,y)\)(\(x\le y\),包括\(x=y\))之间异或路径值之和:记\(x,y\)之间的(唯一)简单路径依次经过点\(p_1,p_2,\dots,p_k\),其中\(p_1=x,p_k=y\),则异或路径值为\(V_{p_1}\oplus V_{p_2}\oplus\dots\oplus V_{p_k}\)(\(x=y\)时即为\(V_x\))。
这是一道经典的树形 DP(按位统计 + 换根 DP)题。
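记\(F[x][K][0/1]\)表示从\(x\)出发的所有路径(包括只含\(x\)一个点的路径)中,异或值第\(K\)位为\(0/1\)的路径条数。第一次 DFS 只统计\(x\)子树内的路径:记\(t_x\)为\(V_x\)的第\(K\)位,则\(F[x][K][b]=[b=t_x]+\sum_{s\in son(x)}F[s][K][b\oplus t_x]\)。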
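之后再做第二次 DFS 进行换根:对父亲为\(f\)的点\(x\),从\(F[f]\)中减去走进\(x\)子树的那些路径,再在开头接上\(x\),即\(F[x][K][b]\mathrel{+}=F[f][K][b\oplus t_x]-F[x][K][b\oplus t_x\oplus t_f]\),这样所有节点作根的情况就都统计到了。最后\(\sum_x\sum_K 2^K F[x][K][1]\)把每个\(x\ne y\)的点对算了两次、每个单点算了一次,因此再加上\(\sum_i V_i\)后除以\(2\)即为答案。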
/*program from Wolfycz*/
#include<map>
#include<cmath>
#include<cstdio>
#include<vector>
#include<cstring>
#include<iostream>
#include<algorithm>
// Contest shorthands: pair member access and construction.
#define Fi first
#define Se second
#define MK make_pair
#define pii pair<int,int>
// Arithmetic helper and "infinity" sentinels.
#define sqr(x) ((x)*(x))
#define int_inf 0x7f7f7f7f
#define ll_inf 1e18
using namespace std;
// Short integer type aliases.
using ll=long long;
using ui=unsigned int;
using ull=unsigned long long;
inline char gc(){
	// Hands out stdin one byte at a time from a 1 MB static buffer, refilling
	// it with a single fread() whenever every buffered byte has been consumed.
	// When fread() delivers nothing, EOF (narrowed to char) is returned.
	static char buf[1000000];
	static char *head=buf,*tail=buf;
	if (head==tail){
		head=buf;
		tail=buf+fread(buf,1,1000000,stdin);
		if (head==tail) return EOF;
	}
	return *head++;
}
template<typename T>inline T frd(T x){
	// Buffered counterpart of read(): reads the next decimal integer via gc(),
	// accumulating digits into x (whose type fixes the return type).
	// Any '-' seen while skipping non-digit bytes makes the result negative.
	// gc() returns EOF narrowed to char once input is exhausted; stop skipping
	// there so reading past the end returns x*f instead of spinning forever.
	// NOTE: a literal 0xFF input byte compares equal to (char)EOF as well.
	int f=1; char ch=gc();
	for (;ch!=(char)EOF&&(ch<'0'||ch>'9');ch=gc()) if (ch=='-') f=-1;
	for (;ch>='0'&&ch<='9';ch=gc()) x=(x<<1)+(x<<3)+ch-'0';
	return x*f;
}
template<typename T>inline T read(T x){
	// Reads the next decimal integer from stdin via getchar(), accumulating
	// digits into x, whose type also fixes the return type
	// (read(0) -> int, read(0ll) -> long long).
	// Any '-' seen while skipping non-digits makes the result negative.
	int f=1;
	// Keep getchar()'s result as int: narrowing it to char would make EOF
	// indistinguishable (or unreachable) and the skip loop would never end.
	int ch=getchar();
	for (;ch!=EOF&&(ch<'0'||ch>'9');ch=getchar()) if (ch=='-') f=-1;
	// If input ran out before a digit, this loop does nothing and x*f is returned.
	for (;ch>='0'&&ch<='9';ch=getchar()) x=(x<<1)+(x<<3)+ch-'0';
	return x*f;
}
inline void print(int x){
	// Writes x in decimal to stdout with putchar (no trailing separator).
	// The magnitude is handled as unsigned: negating INT_MIN as an int is
	// signed overflow (undefined behaviour), whereas 0u-(unsigned)x is exact.
	unsigned int mag=static_cast<unsigned int>(x);
	if (x<0){
		putchar('-');
		mag=0u-mag;
	}
	// Collect digits least-significant first, then emit them in reverse.
	char digits[3*sizeof(unsigned int)+1];
	int len=0;
	do{
		digits[len++]=static_cast<char>('0'+mag%10);
		mag/=10;
	}while (mag);
	while (len>0) putchar(digits[--len]);
}
// Size limits: up to N vertices; bits 0..M (inclusive) of each value are tracked.
const int N=1e5,M=20;
// Adjacency lists stored as linked edge arrays: now[x] is the most recently
// added edge leaving x, pre[e] the next edge in that list, child[e] its head.
int pre[(N<<1)+10],now[N+10],child[(N<<1)+10],tot;
// V[x]: value on vertex x.
// F[x][K][b]: number of paths starting at x (the one-vertex path {x} included)
// whose xor has bit K equal to b. After Dfs it covers paths inside x's subtree;
// after Rorate it covers paths anywhere in the tree.
long long V[N+10],F[N+10][M+10][2];
// Scratch space for the traversals: BFS visiting order and each vertex's parent.
static int bfs_ord[N+10],bfs_par[N+10];
void join(int x,int y){pre[++tot]=now[x],now[x]=tot,child[tot]=y;}
void insert(int x,int y){join(x,y),join(y,x);}
// Fills bfs_ord with the vertices reachable from root without stepping back to
// fa (i.e. root's subtree when fa is its parent), in BFS order, and records
// each vertex's parent in bfs_par. Returns how many vertices were listed.
// Iterative on purpose: recursion would nest once per tree level, and a
// path-shaped tree of 1e5 vertices can overflow a small default stack.
static int bfs_from(int root,int fa){
	int cnt=0;
	bfs_ord[cnt++]=root;
	bfs_par[root]=fa;
	for (int i=0;i<cnt;i++){
		int x=bfs_ord[i];
		for (int p=now[x];p;p=pre[p]){
			int son=child[p];
			if (son==bfs_par[x]) continue;
			bfs_par[son]=x;
			bfs_ord[cnt++]=son;
		}
	}
	return cnt;
}
// Bit K of vertex x's value (K must be in scope where this is used).
#define T(x) ((V[x]>>K)&1)
// Subtree DP for the subtree of x (parent fa). Vertices are processed in
// reverse BFS order, so every child is finished before its parent reads it.
void Dfs(int x,int fa){
	int cnt=bfs_from(x,fa);
	for (int i=cnt-1;i>=0;i--){
		int u=bfs_ord[i];
		// The single-vertex path {u}.
		for (int K=0;K<=M;K++) F[u][K][T(u)]++;
		for (int p=now[u];p;p=pre[p]){
			int son=child[p];
			if (son==bfs_par[u]) continue;
			// Prepending u to a path from son flips bit K iff bit K of V[u] is 1.
			for (int K=0;K<=M;K++){
				F[u][K][0]+=F[son][K][T(u)^0];
				F[u][K][1]+=F[son][K][T(u)^1];
			}
		}
	}
}
// Rerooting pass: turns subtree counts into whole-tree counts. Vertices are
// processed in BFS order, so a parent's F is already final when its child is
// updated. The starting vertex x is only updated when fa is nonzero.
void Rorate(int x,int fa){
	int cnt=bfs_from(x,fa);
	for (int i=0;i<cnt;i++){
		int u=bfs_ord[i],up=bfs_par[u];
		if (!up) continue;
		for (int K=0;K<=M;K++){
			// Paths from `up` that avoid u's subtree, with u prepended:
			// all paths from up, minus those that go down into u's subtree.
			long long delta[2];
			delta[0]=F[up][K][T(u)^0]-F[u][K][T(up)^T(u)^0];
			delta[1]=F[up][K][T(u)^1]-F[u][K][T(up)^T(u)^1];
			F[u][K][0]+=delta[0];
			F[u][K][1]+=delta[1];
		}
	}
}
#undef T
int main(){
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
int n=read(0);
for (int i=1;i<=n;i++) V[i]=read(0ll);
for (int i=1;i<n;i++){
int x=read(0),y=read(0);
insert(x,y);
}
Dfs(1,0); Rorate(1,0);
ll Ans=0;
for (int K=0;K<=M;K++){
ll res=0;
for (int i=1;i<=n;i++) res+=F[i][K][1];
Ans+=(1ll<<K)*res;
}
for (int i=1;i<=n;i++) Ans+=V[i];
Ans>>=1;
printf("%lld\n",Ans);
return 0;
}