HDU5909Tree Cutting

题目大意

给定一颗树,每个点有点权,问对于每个m,有多少个联通块的权值异或和为m。

题解

解法1:可以考虑树形dp,设dp[u][i]表示以u为根的子树中u必须选,联通块权值异或值为i的联通块个数。

转移是m^2的,用FWT优化为mlogm,总复杂度nmlogm

解法2:考虑加一个限制:给一个根,根必须选。

我们可以考虑在欧拉序上做文章,考虑到一个欧拉序的位置上,下一位置是它的儿子,如果我们选择了儿子节点,就往下一个位置转移,否则就跨过这颗子树,转移到下一次回溯到这个点的位置。

这个过程可以用dfs实现。

然后考虑选定点的过程,可以用点分治优化,复杂度nmlogn。

从运行常数来看,点分治的常数小一些。

代码(FWT)

#include<iostream>
#include<cstdio>
#include<cstring>
#define N 1009
using namespace std;
typedef long long ll;
const int mod=1e9+7;
int dp[N][1<<10],ans[1<<10],head[N],tot,n,m,a[N],inv,tag[1<<10],T;
inline int rd(){
    int x=0;char c=getchar();bool f=0;
    while(!isdigit(c)){if(c=='-')f=1;c=getchar();}
    while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
    return f?-x:x;
}
inline int power(int x,int y){
    int ans=1;
    while(y){
        if(y&1)ans=1ll*ans*x%mod;x=1ll*x*x%mod;y>>=1;
    }
    return ans;
}
struct edge{int n,to;}e[N<<1];
inline void add(int u,int v){e[++tot].n=head[u];e[tot].to=v;head[u]=tot;}
inline void FWT(int *b,int tag){
//    cout<<"(";
    for(int i=1;i<m;i<<=1)
      for(int j=0;j<m;j+=(i<<1))
        for(int k=0;k<i;++k){
            int x=b[j+k],y=b[i+j+k];
            b[j+k]=(x+y)%mod;b[i+j+k]=(x-y+mod)%mod;
            if(tag)b[j+k]=1ll*b[j+k]*inv%mod,b[i+j+k]=1ll*b[i+j+k]*inv%mod;
        }
//    cout<<")";
}
void dfs(int u,int fa){
    dp[u][a[u]]=1;
    for(int i=head[u];i;i=e[i].n)if(e[i].to!=fa){
         int v=e[i].to;
         dfs(v,u);
         for(int j=0;j<m;++j)tag[j]=dp[u][j];
         FWT(tag,0);FWT(dp[v],0);
         for(int j=0;j<m;++j)tag[j]=1ll*tag[j]*dp[v][j]%mod;
         FWT(tag,1);
         for(int j=0;j<m;++j)(dp[u][j]+=tag[j])%=mod;
    }    
    for(int j=0;j<m;++j)(ans[j]+=dp[u][j])%=mod;
}
inline void unit(){
    memset(ans,0,sizeof(ans));
    memset(dp,0,sizeof(dp));
    memset(head,0,sizeof(head));
    tot=0;
}
signed main(){
    T=rd();inv=power(2,mod-2);
    while(T--){
        n=rd();m=rd();int u,v;unit();
        for(int i=1;i<=n;++i)a[i]=rd();
        for(int i=1;i<n;++i){
            u=rd();v=rd();add(u,v);add(v,u);
        }
        dfs(1,0);
        for(int j=0;j<m-1;++j)printf("%d ",ans[j]);printf("%d\n",ans[m-1]);
    }
    return 0;
}
View Code

代码(点分治)

#include<iostream>
#include<cstdio>
#include<cstring>
#define N 1009
using namespace std;
int size[N],tot,head[N],d[N],a[N],sum,m,n,root,dp[N][1<<10],ans[1<<10],T;
bool vis[N];
const int mod=1e9+7;
inline int rd(){
    int x=0;char c=getchar();bool f=0;
    while(!isdigit(c)){if(c=='-')f=1;c=getchar();}
    while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
    return f?-x:x;
}
struct edge{
    int n,to;
}e[N<<1];
inline void add(int u,int v){e[++tot].n=head[u];e[tot].to=v;head[u]=tot;}
void getroot(int u,int fa){
    d[u]=0;size[u]=1;
    for(int i=head[u];i;i=e[i].n)if(e[i].to!=fa&&!vis[e[i].to]){
        int v=e[i].to;getroot(v,u);
        size[u]+=size[v];d[u]=max(d[u],size[v]);
    }
    d[u]=max(d[u],sum-size[u]);
    if(d[u]<d[root])root=u;
}
void getsize(int u,int fa){
    size[u]=1;
    for(int i=head[u];i;i=e[i].n)if(e[i].to!=fa&&!vis[e[i].to]){
        int v=e[i].to;getsize(v,u);size[u]+=size[v];
    }
}
inline void MOD(int &x){while(x>=mod)x-=mod;}
void calc(int u,int fa){
    for(int i=head[u];i;i=e[i].n)if(e[i].to!=fa&&!vis[e[i].to]){
        int v=e[i].to;
        for(int j=0;j<m;++j)MOD(dp[v][j^a[v]]+=dp[u][j]);
        calc(v,u);
        for(int j=0;j<m;++j)MOD(dp[u][j]+=dp[v][j]),dp[v][j]=0;
    }
}
void solve(int u){
    vis[u]=1;
    dp[u][a[u]]=1;calc(u,0);
    for(int i=0;i<m;++i)MOD(ans[i]+=dp[u][i]),dp[u][i]=0;
    for(int i=head[u];i;i=e[i].n)if(!vis[e[i].to]){
        int v=e[i].to;
        root=n+1;sum=size[v];
        getroot(v,u);getsize(root,0);
        solve(root);
    }
}
inline void unit(){
    memset(vis,0,sizeof(vis));
    memset(head,0,sizeof(head));tot=0;
} 
int main(){
//    freopen("in","r",stdin);
//    freopen("out","w",stdout);
    T=rd();
    while(T--){
        n=rd();m=rd();unit();int u,v;
        for(int i=1;i<=n;++i)a[i]=rd();
        for(int i=1;i<n;++i){u=rd();v=rd();add(u,v);add(v,u);}
        root=n+1;d[root]=n;sum=n;
        getroot(1,0);getsize(root,0);
        solve(root);
        for(int i=0;i<m-1;++i)printf("%d ",ans[i]),ans[i]=0;
        printf("%d\n",ans[m-1]);ans[m-1]=0;
    }
    return 0;
} 
View Code
posted @ 2019-01-16 16:08  comld  阅读(189)  评论(0编辑  收藏  举报