[HDU5909]Tree Cutting
题目大意:
给定一棵树,每个节点\(i\)上有权值\(V_i\),定义树上一个连通块\(S\)的权值为其中所有节点权值的异或和\(\bigoplus_{i\in S}V_i\)。对每个\(x\in[0,m)\),求权值为\(x\)的连通块个数
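考虑dp,设\(f[i][j]\)表示在\(i\)的子树中、包含\(i\)(即以\(i\)为最高点)且权值为\(j\)的连通块个数。合并儿子\(s\)时,\(s\)的子树要么完全不选,要么选一个包含\(s\)的连通块,因此把\(f[i]\)与"\(f[s]\)在下标\(0\)处加\(1\)"(加的\(1\)表示不选)做一次异或卷积即可,用FWT优化;答案为\(\sum_i f[i][x]\)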
/*program from Wolfycz*/
#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
// Shorthand aliases, apparently carried over from a personal template.
// None of `inf`, `ll`, `ui`, `ull` is referenced by the code in this file
// (the `1ll` factors below are plain long long literals, not the alias).
#define inf 0x7f7f7f7f
using namespace std;
typedef long long ll;
typedef unsigned int ui;
typedef unsigned long long ull;
// Buffered single-character reader over stdin.
// When the 1000000-byte buffer is drained it is refilled with fread; if the
// refill yields nothing, EOF is returned (converted to char).
inline char gc(){
	static char buf[1000000],*head=buf,*tail=buf;
	if (head==tail){
		head=buf;
		tail=buf+fread(buf,1,1000000,stdin);
		if (head==tail) return EOF;
	}
	return *head++;
}
// Read the next decimal integer through the buffered gc().
// Any '-' among the characters skipped before the first digit makes the result
// negative. Returns 0 if input ends before a digit is found.
inline int frd(){
	int x=0,f=1;char ch=gc();
	for (;ch<'0'||ch>'9';ch=gc()){
		// gc() reports end of input as EOF squeezed into a char. Without this
		// check the loop would call gc() forever once input is exhausted.
		// (Assumes ASCII input, so a raw 0xFF byte never appears.)
		if (ch==static_cast<char>(EOF)) return 0;
		if (ch=='-') f=-1;
	}
	for (;ch>='0'&&ch<='9';ch=gc()) x=(x<<1)+(x<<3)+ch-'0';
	return x*f;
}
// Read the next decimal integer from stdin via getchar().
// Any '-' among the characters skipped before the first digit makes the result
// negative. Returns 0 if end of input is reached before any digit.
inline int read(){
	// Keep getchar()'s result as int so EOF stays distinguishable from real bytes.
	int x=0,f=1,ch=getchar();
	for (;ch<'0'||ch>'9';ch=getchar()){
		// Previously the skip loop spun forever at end of input.
		if (ch==EOF) return 0;
		if (ch=='-') f=-1;
	}
	for (;ch>='0'&&ch<='9';ch=getchar()) x=(x<<1)+(x<<3)+ch-'0';
	return x*f;
}
// Write x to stdout in decimal (with a leading '-' when negative), no newline.
// The magnitude is computed in unsigned arithmetic: negating INT_MIN as an int
// overflows (undefined behaviour), whereas unsigned negation is well-defined.
inline void print(int x){
	unsigned int mag=static_cast<unsigned int>(x);
	if (x<0){
		putchar('-');
		mag=0u-mag;
	}
	char digits[24];	// ample for any unsigned int width
	int len=0;
	do digits[len++]=static_cast<char>('0'+mag%10); while (mag/=10);
	while (len>0) putchar(digits[--len]);
}
const int N=1000;		// node capacity (per-node arrays are sized N+10)
const int M=1<<10;		// capacity for the transform length m
const int Mod=1000000007;	// prime modulus for all counts
const int inv=500000004;	// inverse of 2 modulo Mod: 2*inv == 1 (mod Mod)

// Halve x in place modulo Mod, i.e. multiply it by the inverse of 2.
void div(int &x){
	long long prod=static_cast<long long>(x)*inv;
	x=static_cast<int>(prod%Mod);
}
// In-place XOR (Walsh-Hadamard) transform of a[0..n-1] modulo Mod.
// Expects n to be a power of two. flag == 1 applies the forward transform.
// flag == -1 applies the inverse by halving both butterfly outputs at every
// level (log2(n) halvings divide the result by n overall).
void FWT(int *a,int n,int flag){
	for (int half=1;(half<<1)<=n;half<<=1){
		for (int base=0;base<n;base+=half<<1){
			int *lo=a+base,*hi=a+base+half;
			for (int k=0;k<half;k++){
				int u=lo[k],v=hi[k];
				lo[k]=(u+v)%Mod;
				hi[k]=(u-v+Mod)%Mod;
				if (flag==-1){
					div(lo[k]);
					div(hi[k]);
				}
			}
		}
	}
}
// Tree edges kept as singly linked lists per vertex: now[v] is the most recently
// added edge out of v (0 = none), pre[e] links to the next older edge, and
// child[e] is the vertex edge e points to. insert() stores each undirected edge
// twice, hence room for 2*N edge records.
int now[N+10];
int pre[(N<<1)+10];
int child[(N<<1)+10];
// f[v][*]: per-node DP array indexed by XOR value; Ans[*]: per-value answer.
int f[N+10][M+10];
int Ans[M+10];
// n: node count; m: number of XOR values (DP array length); tot: edges stored.
int n;
int m;
int tot;
// Push a directed edge x -> y onto the front of x's edge list.
void join(int x,int y){
	tot++;
	child[tot]=y;
	pre[tot]=now[x];
	now[x]=tot;
}
// Record the undirected edge {x, y}: one directed copy in each endpoint's list.
void insert(int x,int y){
	join(x,y);
	join(y,x);
}
// Tree DP over the subtree of x (fa = parent, 0 for the root).
// On entry f[x] holds the one-hot vector of x's weight (set in main).
// On exit f[x] holds, in the XOR-transformed domain, the vector g[x] + e_0, where:
//   g[x][j] = number of connected blocks inside x's subtree that contain x
//             and whose XOR of weights is j;
//   e_0     = unit vector at index 0, meaning "the parent does not extend into
//             this subtree".
// The parent multiplies these transformed vectors pointwise.
void dfs(int x,int fa){
	FWT(f[x],m,1);
	for (int p=now[x],son=child[p];p;p=pre[p],son=child[p]){
		if (son==fa) continue;
		dfs(son,x);
		// XOR convolution is a pointwise product in the transformed domain.
		for (int i=0;i<m;i++) f[x][i]=1ll*f[x][i]*f[son][i]%Mod;
	}
	// Adding e_0 in the coefficient domain is the same as adding 1 to every
	// transformed entry, because the transform of e_0 is the all-ones vector.
	// This replaces an inverse/forward FWT round trip with one O(m) pass and
	// gives bit-identical values modulo Mod.
	for (int i=0;i<m;i++) if (++f[x][i]==Mod) f[x][i]=0;
}
// For each test case: read n, m, the node weights and n-1 edges, then print, for
// every x in [0, m), the number of connected blocks whose XOR weight is x.
int main(){
	for (int T=read();T;T--){
		n=read(),m=read(),tot=0;
		memset(f,0,sizeof(f));
		memset(now,0,sizeof(now));
		memset(Ans,0,sizeof(Ans));
		for (int i=1;i<=n;i++) f[i][read()]=1;
		for (int i=1;i<n;i++){
			int x=read(),y=read();
			insert(x,y);
		}
		dfs(1,0);
		// After dfs every f[i] holds the transformed vector of
		// (blocks whose top node is i) + e_0. The transform is linear, so
		// sum all nodes first and invert once, instead of inverting n times.
		for (int i=1;i<=n;i++)
			for (int j=0;j<m;j++)
				Ans[j]=(Ans[j]+f[i][j])%Mod;
		FWT(Ans,m,-1);
		// Remove the n copies of e_0 (one "not taken" option per node).
		Ans[0]=(Ans[0]-n%Mod+Mod)%Mod;
		for (int i=0;i<m;i++) printf("%d",Ans[i]),putchar(i==m-1?'\n':' ');
	}
	return 0;
}