HDU 4670 Cube number on a tree(点分治 / centroid decomposition)

 这次写的是不容斥的版本,WA 了好几次;又改成容斥的版本,还是没过。一怒之下把所有的 int 改成 long long 就过了……(原因:toCube 把每个数按 K 个质数的指数模 3 编码成三进制数,编码值最大可达 3^K−1,K 较大时会超出 int 范围;点权本身也需要按 long long 读入。)

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<map>
// Inclusive loop: i runs from a to b (both ends included).
#define REP(i,a,b) for(int i=a;i<=b;i++)
// Zero-fill an array. Uses sizeof(a), so it only works on real arrays, not pointers.
#define MS0(a) memset(a,0,sizeof(a))

using namespace std;

typedef long long ll;
// Capacity of every global array below. Sizes are not checked against N.
const int maxn=1000100;
// Sentinel larger than any component size; used to reset `balance` before get_rt.
const int INF=1e9+10;

// N = number of nodes (1-based ids), K = number of primes read into p[].
ll N,K;
// p[0..K-1]: the primes used to encode node values.
ll p[maxn];
// val[i]: base-3 code of node i's value. Digit j = (exponent of p[j]) mod 3 (see toCube).
ll val[maxn];
// Scratch variables for reading edges in main.
ll u,v;
// Forward-star adjacency list. Edge ids are 1-based (tot counts edges).
// e[id] = target node, Next[id] = next edge of the same source, first[node] = head.
// -1 terminates a list (Init fills first[] with -1; loops test ~i).
// main adds every tree edge twice, once per direction.
ll e[maxn],tot;
ll first[maxn],Next[maxn];
// vis[x]: x has already been chosen as a centroid and splits the tree.
bool vis[maxn];
// rt/balance: best centroid candidate found by get_rt and its largest remaining component.
ll rt,balance;
// id[code]: for the current centroid, how many already-processed path ends have this code.
map<ll,ll> id;
// d[1..dn]: nodes collected by the latest dfs_d call.
ll d[maxn],dn;
// s[x]: code of the path from the current centroid to x, both ends included.
ll s[maxn];

void Init()
{
    // Empty every adjacency list (-1 = end of list) and restart edge numbering.
    memset(first,-1,sizeof(first));
    tot=0;
}

void addedge(ll u,ll v)
{
    // Allocate the next edge id, then link the directed edge u->v at the head of u's list.
    ++tot;
    e[tot]=v;
    Next[tot]=first[u];
    first[u]=tot;
}

// Returns n^k by binary exponentiation.
//   n: base; k: exponent. k <= 0 returns 1.
// The base is squared only while exponent bits remain, so every intermediate
// power formed is at most n^k in magnitude. The previous version squared once
// more after the last bit, which is signed overflow (UB) for e.g. qpow(3, 39)
// even though 3^39 fits in long long. It also looped forever for negative k.
ll qpow(ll n,ll k)
{
    ll res=1;
    while(k>0){
        if(k&1) res*=n;
        k>>=1;
        if(k) n*=n;   // skip the final, unused square
    }
    return res;
}

// Encodes x as a K-digit base-3 number: digit i is (exponent of p[i] in x) mod 3.
// A code of 0 means every exponent of p[0..K-1] in x is a multiple of 3.
// Exponents add under multiplication, so the code of a product is the add3 of the codes.
ll toCube(ll x)
{
    ll code=0;
    ll weight=1;   // 3^i for the digit currently being written
    for(int i=0;i<K;i++){
        ll exponent=0;
        for(;x%p[i]==0;x/=p[i]) exponent++;
        code+=weight*(exponent%3);
        weight*=3;
    }
    return code;
}

// Digit-wise base-3 addition without carries over the low K digits:
// result digit i = (digit_i(a) + digit_i(b)) mod 3.
ll add3(ll a,ll b)
{
    ll sum=0;
    ll place=1;    // 3^i
    for(int i=0;i<K;i++){
        const ll digit=(a%3+b%3)%3;
        a/=3;
        b/=3;
        sum+=digit*place;
        place*=3;
    }
    return sum;
}

// Digit-wise base-3 subtraction without borrows over the low K digits:
// result digit i = (digit_i(a) - digit_i(b)) mod 3. Inverse of add3: add3(cut3(a,b), b) == a.
ll cut3(ll a,ll b)
{
    ll diff=0;
    ll place=1;    // 3^i
    for(ll left=K; left>0; --left){
        const ll da=a%3;
        const ll db=b%3;
        a/=3;
        b/=3;
        diff+=((da-db+3)%3)*place;
        place*=3;
    }
    return diff;
}

// Preorder walk of the not-yet-split part of the tree below u (parent f).
// Appends each node to d[1..dn] and stores in s[node] the accumulated code
// `dep` of the path from the current centroid down to that node.
void dfs_d(ll u,ll f,ll dep)
{
    d[++dn]=u;
    s[u]=dep;
    for(int edge=first[u]; edge!=-1; edge=Next[edge]){
        const int to=e[edge];
        if(to!=f && !vis[to])
            dfs_d(to,u,add3(dep,val[to]));
    }
}

// Returns the size of u's subtree (parent f, stopping at visited nodes).
// As a side effect, keeps in rt/balance the node whose largest remaining
// component, assuming the whole component has size sz, is smallest so far.
// Only a strictly better candidate replaces the current one.
ll get_rt(ll u,ll f,int sz)
{
    ll subtree=1;    // nodes in u's subtree, u included
    ll heaviest=0;   // biggest piece left if u were removed
    for(int edge=first[u]; edge!=-1; edge=Next[edge]){
        const int to=e[edge];
        if(to==f || vis[to]) continue;
        const ll child=get_rt(to,u,sz);
        subtree+=child;
        if(child>heaviest) heaviest=child;
    }
    const ll above=sz-subtree;   // the part of the component outside u's subtree
    if(above>heaviest) heaviest=above;
    if(heaviest<balance){
        balance=heaviest;
        rt=u;
    }
    return subtree;
}

// Counts the paths (single-node paths included) inside the unvisited component
// containing u whose node codes add3 to 0. A code of 0 means every prime
// exponent of the path product is a multiple of 3. Each path is counted at the
// first centroid that lies on it, and the function then recurses into the
// remaining components.
ll solve(int u)
{
    // Pass 1 only measures the component size; the rt/balance it leaves are discarded.
    rt=u;balance=INF;
    ll sz=get_rt(u,0,N);
    // Pass 2 uses the real size to pick the centroid.
    rt=u;balance=INF;
    get_rt(u,0,sz);
    u=rt;
    vis[u]=1;
    ll res=0;
    id.clear();
    // The centroid alone: s[u] = val[u]. Its entry in id also serves paths that end at u.
    s[u]=val[u];
    id[s[u]]++;
    if(val[u]==0) res++;
    for(int i=first[u];~i;i=Next[i]){
        int to=e[i];
        if(vis[to]) continue;
        dn=0;
        dfs_d(to,u,add3(val[u],val[to]));
        // s[x] and s[y] both include val[u]. The path x..u..y therefore has code
        // s[x] + s[y] - val[u], which is 0 iff s[y] == cut3(val[u], s[x]).
        // Only earlier subtrees (and u itself) are in id yet, so each pair is counted once.
        // find() is used instead of operator[] so missed lookups do not insert zero entries.
        REP(j,1,dn){
            ll x=s[d[j]];
            map<ll,ll>::iterator it=id.find(cut3(val[u],x));
            if(it!=id.end()) res+=it->second;
        }
        REP(j,1,dn) id[s[d[j]]]++;
    }
    // Recurse into each component left after removing the centroid.
    for(int i=first[u];~i;i=Next[i]){
        int to=e[i];
        if(vis[to]) continue;
        res+=solve(to);
    }
    return res;
}

// Reads test cases until EOF. Each case has:
//   N K, then K primes, then N node values, then N-1 undirected edges.
// For each case, prints the count produced by solve(1).
// All targets (N, K, p[], x, u, v) are long long, so the matching specifier is
// %lld. The previous %d wrote into long long objects (UB), and %I64d is
// MSVC-only. The loop now stops unless both N and K were read; the old
// ~scanf(...) test kept looping on a partial match.
int main()
{
    //freopen("in.txt","r",stdin);
    while(scanf("%lld%lld",&N,&K)==2){
        REP(i,0,K-1) scanf("%lld",&p[i]);
        REP(i,1,N){
            ll x;scanf("%lld",&x);
            val[i]=toCube(x);
        }
        Init();
        REP(i,1,N-1){
            scanf("%lld%lld",&u,&v);
            addedge(u,v);
            addedge(v,u);
        }
        MS0(vis);
        printf("%lld\n",solve(1));
    }
    return 0;
}
/**
5
3 2 3 5
2500 200 9 270000 27
4 2
3 5
2 5
4 1

2
2 3 5
9 3
1 2

6
3 2 3 5
10 10 10 10 10 10
1 2
2 3
3 4
4 5
5 6

6
3 2 3 5
216 10 10 10 25 5
1 2
2 3
3 4
4 5
5 6

12
3 2 3 5
1 3 5 3 3 9 1 5 5 2 4 2
1 4
4 2
4 5
4 6
4 7
2 3
5 8
5 9
6 10
6 11
7 12

5
3  2 3 5
1 1 1 1 1
1 2
1 3
1 4
1 5



*/
View Code

 

posted @ 2016-03-15 23:17  __560  阅读(305)  评论(0)  编辑  收藏  举报