hdu 5909 Tree Cutting(FWT优化DP)

http://acm.hdu.edu.cn/showproblem.php?pid=5909

 

题意:

给出一颗带点权的树,输出有多少联通块的点权异或和=[1,m)

 

dp[x][i] 以x为根的子树中,联通块内一定有x,目前异或和为i 的联通块 个数

dp[x][i] =  dp[x][i] + Σ Σ [j^k==i] dp[x][j]^dp[son][k]

时间复杂度为 n*m*m

用FWT优化到 n*m*logm

 

 

#include<cstdio>
#include<vector>
#include<cstring>
#include<iostream>

using namespace std;

const int mod=1e9+7;

#define N 1001
const int M=(1<<10)+2;

int m;
int val[N];
vector<int>V[N];

int dp[N][M],tmp[M];
int ans[M];

long long inv;

void read(int &x)
{
    x=0; char c=getchar();
    while(!isdigit(c)) c=getchar();
    while(isdigit(c)) { x=x*10+c-'0'; c=getchar(); }
}

void FWT_xor(int *a,int n)
{
    int x,y;
    for(int d=1;d<n;d<<=1)
        for(int m=d<<1,i=0;i<n;i+=m)
            for(int j=0;j<d;++j)
            {
                x=a[i+j]; y=a[i+j+d];
                a[i+j]=(x+y)%mod; a[i+j+d]=(x-y+mod)%mod;
            }
}

void IFWT_xor(int *a,int n)
{
    int x,y;
    for(int d=1;d<n;d<<=1)
        for(int m=d<<1,i=0;i<n;i+=m)
            for(int j=0;j<d;++j)
            {
                x=a[i+j]; y=a[i+j+d];
                a[i+j]=(x+y)*inv%mod; a[i+j+d]=((x-y)*inv%mod+mod)%mod;
            }
}

void solve(int *a,int *b,int n)
{
    FWT_xor(a,n);
    FWT_xor(b,n);
    for(int i=0;i<n;++i) a[i]=(1LL*a[i]*b[i])%mod;
    IFWT_xor(a,n);
}

void dfs(int x,int fa)
{
    dp[x][val[x]]=1;
    int siz=V[x].size(),t;
    for(int i=0;i<siz;++i)
    {
        t=V[x][i];
        if(t==fa) continue;
        dfs(t,x);
        for(int i=0;i<m;++i) tmp[i]=dp[x][i];
        solve(dp[x],dp[t],m);
        for(int i=0;i<m;++i) (dp[x][i]+=tmp[i])%=mod;
    }
    for(int i=0;i<m;++i) (ans[i]+=dp[x][i])%=mod;
}

long long Pow(long long a,int b)
{
    long long res=1;
    for(;b;b>>=1,a=a*a%mod)
        if(b&1) res=res*a%mod;
    return res;
}

int main()
{
    int T;
    int n,u,v;
    read(T);
    inv=Pow(2,mod-2);
    while(T--)
    {
        memset(dp,0,sizeof(dp));
        memset(ans,0,sizeof(ans));
        read(n); read(m);
        for(int i=1;i<=n;++i) read(val[i]);
        for(int i=1;i<=n;++i) V[i].clear();
        for(int i=1;i<n;++i)
        {
            read(u); read(v);
            V[u].push_back(v);
            V[v].push_back(u);
        }
        dfs(1,0);
        for(int i=0;i<m-1;++i) printf("%d ",ans[i]);
        printf("%d\n",ans[m-1]);
    }
}    
posted @ 2018-03-18 18:00  TRTTG  阅读(230)  评论(0编辑  收藏  举报