HDU 4670 Cube number on a tree ( 树的点分治 )

题意 : 给你一棵树 。 树的每一个结点都有一个权值 。 问你有多少条路径权值的乘积是一个全然立方数 。

题目中给了你 K 个素数 ( K <= 30 ) , 全部权值都能分解成这k个素数


思路 : 一个全然立方数的素因子个数都是三的倍数 , 所以我们仅仅要求各个素数的个数即可了 , 而且我们仅仅关心个数对三的余数

所以我们能够用一个 长整形来表示每一个结点到根的各个素因子的个数( 三进制压缩 ) 。

只是由于用位运算会快一点 , 所以我用了四进制。即每两位表示一个素因子的个数 。

中间合并的时候计算刚好相加之后变成0的数,然后用map统计就能够了


注意 : 题目中 pi 的值 , 0 是取不到的(  纠结了好久该怎么处理 = = )


#include <stdio.h>
#include <string.h>
#include <map>
#include <algorithm>
#pragma comment(linker, "/STACK:102400000,102400000")
using namespace std;

#define MAXN 50005
#define INF 0x3f3f3f3f
__int64 prime[35] ;
int n , k ;

struct Tree{  
    struct Edge{  
        int to , nex ;  
        Edge(){}  
        Edge( int _to  , int _nex ) {  
            to = _to ;  
            nex = _nex ;  
        }  
    }edge[MAXN*2] ;  

    int head[MAXN] ;  
    int Index ;  
  
    void init(){  
        memset( head , -1 , sizeof(head) ) ;  
        Index = 0 ;  
    }  
  
    void add( int from , int to ) {  
        edge[Index] = Edge( to , head[from] ) ;  
        head[from] = Index ++ ;  
    }  
} tree ;

inline int get( __int64 val , int pos ) {
    return ( val >> ( pos << 1 ) ) & 3 ;
}

inline void set( __int64 & val , int pos , int k ) {
    val |= ( (__int64)k << ( pos << 1 ) ) ;
}

__int64 add( __int64 a , __int64 b ) {
    __int64 ans = 0 ;
    for( int i = 0 ; i < k ; i ++ ) {
        int aa = get( a , i ) ;
        int bb = get( b , i ) ;
        aa += bb ;
        aa %= 3 ; ;
        set( ans , i , aa ) ;
    }
    return ans ;
}

__int64 fan( __int64 a  ) {
    __int64 ans = 0 ;
    for( int i = 0 ; i < k ; i ++ ) {
        int tmp = get( a , i ) ;
        tmp =  3 - tmp ;
        set( ans , i , tmp % 3 ) ;
    }
    return ans ;
}

bool vis[MAXN] ;
__int64 val[MAXN] ;
__int64 ans ;

// 第i结点为根的子树的最大个数  
int dp[MAXN] ;  
// 第i结点为根的子树的大小  
int sum[MAXN] ;  

int Sum ;  
int Min , Minid ;  

void dfs1( int u , int f ) {  
    sum[u] = 1 ;  
    dp[u] = 0 ;  
    for( int i = tree.head[u] ; ~i ; i = tree.edge[i].nex ) {  
        int ch = tree.edge[i].to ;  
        if( ch == f || vis[ch] ) continue ;  
        dfs1( ch , u  ) ;  
        sum[u] += sum[ch] ;  
        dp[u] = max( dp[u] , sum[ch] ) ;  
    }  
}  
  
void dfs2( int u , int f ) {  
    int M = max( dp[u] , Sum - sum[u] ) ;  
    if( M < Min ) {  
        Min = M ;  
        Minid = u ;  
    }  
    for( int i = tree.head[u] ; ~i ; i = tree.edge[i].nex ) {  
        int ch = tree.edge[i].to ;  
        if( ch == f || vis[ch] ) continue ;  
          
        dfs2( ch , u ) ;  
    }  
}  
  
// 返回树的重心  
int getRoot( int u ) {  
    // 第一遍dfs求出全部点的子树大小。和以孩子为跟的子树个数的最大值  
    dfs1( u , 0 ) ;  
    Sum = sum[u] ;  
    Min = 0x3f3f3f3f ;  
    Minid = -1 ;  
    // 第二次dfs求重心  
    /* 事实上假设整棵树求重心的话,我们一边dfs即可。由于这个是部分树,我们要dfs来确定哪些结点在当前树中 */  
    dfs2( u , 0 ) ;  
    return Minid ;  
}  

int tot ;
__int64 Val[MAXN] ;

void getVal( int u , int fa , __int64 vv ) {
    Val[tot++] = add( vv , val[u] ) ;
    for( int i = tree.head[u] ; ~i ; i = tree.edge[i].nex ) {
        int to = tree.edge[i].to ;
        if( vis[to] == true ) 
            continue ;
        if( to == fa ) continue ;
        getVal( to , u , add( vv , val[u] ) ) ;
    }
}

map<__int64,int> mp ;

__int64 cal( int root , __int64 val , __int64 root_val ) {
    tot = 0 ;
    getVal( root , 0 , val ) ;
    //sort( Val , Val + tot ) ;
    mp.clear() ;
    __int64 ans = 0 ;
    for( int i = 0 ; i < tot ; i ++ ) {
        ans += mp[ add( root_val , fan(Val[i]) ) ] ;
        mp[Val[i]] ++ ;
    }
    return ans ;
}

void solve( int u ) {  
    // 得到当前子树的重心  
    int root = getRoot( u ) ;  
    // 计算以root为根的结果  
    ans += cal( root , 0 , val[root] ) ;  
    if( val[root] == 0 ) ans ++ ;
    vis[root] = true ;  
    for( int i = tree.head[root] ; ~i ; i = tree.edge[i].nex ) {  
        int ch = tree.edge[i].to ;  
        if( vis[ch] ) continue ;  
        ans -= cal( ch , val[root] , val[root] ) ;  
        solve( ch ) ;     
    }  
}

int main(){
    while( scanf( "%d" , &n ) != EOF )  {
        scanf( "%d" , &k ) ;
        for( int i = 0 ; i < k ; i ++ ) {
            scanf( "%d" , &prime[i] ) ;
        }
        tree.init() ;
        memset( vis , false , sizeof(vis) );
        memset( val , 0 , sizeof(val) ) ;
        for( int i = 1 ; i <= n ; i ++ ) {
            __int64 tmp ;
            scanf( "%I64d" , &tmp ) ;
            for( int j = 0 ; j < k ; j ++ ) {
                int cnt = 0 ;
                while( tmp % prime[j] == 0 ) {
                    cnt ++ ;
                    tmp /= prime[j] ;
                }
                cnt %= 3 ;
                set( val[i] , j , cnt ) ;
            }
        }
        for( int i = 0 ; i < n - 1 ; i ++ ) {
            int u , v ;
            scanf( "%d%d" , &u , &v ) ;
            tree.add( u , v ) ;
            tree.add( v , u ) ;
        }
        ans = 0 ;
        solve( 1 ) ;
        printf( "%I64d\n" , ans ) ;
    }
    return 0 ;
}


posted @ 2017-04-13 20:32  yangykaifa  阅读(576)  评论(0编辑  收藏  举报