HDU4670 Cube number on a tree 树分治

    人生的第一道树分治,要是早点学我南京赛就不用那么挫了,树分治的思路其实很简单,就是对子树找到一个重心(Centroid),实现重心分解,然后递归的解决分开后的树的子问题,关键是合并,当要合并跨过重心的两棵子树的时候,需要有一个接近O(n)的方法,因为f(n)=kf(n/k)+O(n)解出来才是O(nlogn).在这个题目里其实就是将第一棵子树的集合里的每个元素,判下有没符合条件的,有就加上,然后将子树集合压进大集合,然后继续搞第二棵乃至第n棵.我的过程用了map,合并是nlogn的所以代码速度颇慢,大概6s,题目时限10s,可以改成hash应该会快许多,毕竟用map实在太慢,用vector也可以,具体可以参见挑战程序设计竞赛代码.下面的代码查找重心用了挑战的代码.

#pragma comment(linker, "/STACK:102400000,102400000")
#include<iostream>
#include<cstring>
#include<string>
#include<cstdio>
#include<algorithm>
#include<map>
#include<vector>
#define maxv 50000
#define ll long long
using namespace std;

int n,k;
vector<int> G[maxv+50];
ll val[maxv+50];
ll prime[maxv+50];
ll convert_three(ll v)
{
    ll bas=1;ll res=0;
    for(int i=0;i<k;++i){
        int num=0;
        while(v%prime[i]==0){
            v/=prime[i];
            num++;
        }
        num%=3;res+=num*bas;
        bas*=3;
    }
    return res;
}

ll xor(ll x,ll y)
{
    ll res=0;ll bas=1;
    for(int i=0;i<k;++i){
        res+=((x%3)+(y%3))%3*bas;
        x/=3;y/=3;
        bas*=3;
    }
    return res;
}

ll inv(ll x)
{
    ll res=0;ll bas=1;
    for(int i=0;i<k;++i){
        res+=((3-(x%3))%3)*bas;
        x/=3;
        bas*=3;
    }
    return res;
}

void print(ll x){
    while(x){
        cout<<x%3;
        x/=3;
    }
    cout<<endl;
}

bool centroid[maxv+50];
int ssize[maxv+50];
int ans;

map<ll,int> sta;
map<ll,int>::iterator it;
int compute_ssize(int v,int p)
{
    int c=1;
    for(int i=0;i<G[v].size();++i){
        int w=G[v][i];
        if(w==p||centroid[w]) continue;
        c+=compute_ssize(G[v][i],v);
    }
    ssize[v]=c;
    return c;
}

pair<int,int> search_centroid(int v,int p,int t)
{
    pair<int,int> res=make_pair(INT_MAX,-1);
    int s=1,m=0;
    for(int i=0;i<G[v].size();++i){
        int w=G[v][i];
        if(w==p||centroid[w]) continue;
        res=min(res,search_centroid(w,v,t));
        m=max(m,ssize[w]);
        s+=ssize[w];
    }
    m=max(m,t-s);
    res=min(res,make_pair(m,v));
    return res;
}

void enumerate_mul(int v,int p,ll d,map<ll,int> &ds)
{
    if(ds.count(d)) ds[d]++;
    else ds[d]=1;
    for(int i=0;i<G[v].size();++i){
        int w=G[v][i];
        if(w==p||centroid[w]) continue;
        enumerate_mul(w,v,xor(d,val[w]),ds);
    }
}

void solve(int v)
{
    compute_ssize(v,-1);
    int s=search_centroid(v,-1,ssize[v]).second;
    centroid[s]=true;
    for(int i=0;i<G[s].size();++i){
        if(centroid[G[s][i]]) continue;
        solve(G[s][i]);
    }
    sta.clear();
    sta[val[s]]=1;map<ll,int> tds;
    for(int i=0;i<G[s].size();++i){
        if(centroid[G[s][i]]) continue;
        tds.clear();
        enumerate_mul(G[s][i],s,val[G[s][i]],tds);
        it=tds.begin();
        while(it!=tds.end()){
            ll rev=inv((*it).first);
            if(sta.count(rev)){
                ans+=sta[rev]*(*it).second;
            }
            ++it;
        }
        it=tds.begin();
        while(it!=tds.end()){
            ll  vv=xor((*it).first,val[s]);
            if(sta.count(vv)){
                sta[vv]+=(*it).second;
            }
            else{
                sta[vv]=(*it).second;
            }
            ++it;
        }
    }
    centroid[s]=false;
}

int main()
{
    while(cin>>n>>k){
        ans=0;
        for(int i=0;i<k;++i){
            scanf("%I64d",&prime[i]);
        }
        G[0].clear();
        for(int i=1;i<=n;++i){
            scanf("%I64d",&val[i]);
            val[i]=convert_three(val[i]);
            if(val[i]==0) ans++;
            //print(val[i]);
            G[i].clear();
        }
        int u,v;
        for(int i=0;i<n-1;++i){
            scanf("%d%d",&u,&v);
            G[u].push_back(v);
            G[v].push_back(u);
        }
        memset(centroid,0,sizeof(centroid));
        solve(1);
        printf("%d\n",ans);
    }
    return 0;
}

 

posted @ 2013-11-07 01:49  chanme  阅读(453)  评论(0编辑  收藏  举报