Evensgn 剪树枝

题目描述

繁华中学有一棵苹果树。苹果树有 n 个节点(也就是苹果),n − 1 条边(也就

是树枝)。调皮的 Evensgn 爬到苹果树上。他发现这棵苹果树上的苹果有两种:一

种是黑苹果,一种是红苹果。Evensgn 想要剪掉 k 条树枝,将整棵树分成 k + 1 个

部分。他想要保证每个部分里面有且仅有一个黑苹果。请问他一共有多少种剪树枝

的方案?

输入

第一行一个数字 n,表示苹果树的节点(苹果)个数。

第二行一共 n − 1 个数字 p0, p1, p2, p3, ..., pn−2,pi 表示第 i + 1 个节点和 pi 节

点之间有一条边。注意,点的编号是 0 到 n − 1。

第三行一共 n 个数字 x0, x1, x2, x3, ..., xn−1。如果 xi 是 1,表示 i 号节点是黑

苹果;如果 xi 是 0,表示 i 号节点是红苹果。

输出

输出一个数字,表示总方案数。答案对 109 + 7 取模。

样例输入

样例输入 260 1 1 0 41 1 0 0 1 0样例输入 3100 1 2 1 4 4 4 0 80 0 0 1 0 1 1 0 0 1

样例输出

样例输出 12样例输出 21样例输出 327
树归
has[x] x与他父亲相连的连通块里有黑苹果的方案数
no[x]  x与他父亲的联通快里没有黑苹果的方案数
初始化叶子节点 黑苹果 has[x]=1; no[x]=1;
红苹果 has[x]=0;no[x]=1;
黑苹果 no[x]=no[son[i][1]]*no[son[i][2]]* * * *   假设x与父亲相连的边被删掉
            has[x]=no[x]  x与父亲相连的边不被删,且保证连通块内只有一个黑苹果
红苹果 has[x]= ∑ has[son[i]]*no[son[j]] 
            no[x]=no[son[i][1]]*no[son[i][2]]* *  *  * *
            no[x]=no[x]+has[x]  x与其父亲的连边可以被删掉
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#include<queue>
#define maxn 100005
#define mod 1000000007
#define LL long long
using namespace std;
int n;
struct edge
{
    int to,ne;  
}b[maxn*2];
int k=0,head[maxn],a[maxn],fa[maxn];
bool leaf[maxn];
LL has[maxn],no[maxn];
vector <int > v[maxn];
void add(int u,int v)
{
     k++;
     b[k].to=v;b[k].ne=head[u];head[u]=k;
}
queue<int > q;
void dfs(int x)
{
    q.push(x);
    while(!q.empty())
    {
       int z=q.front();q.pop();
       for(int i=head[z];i!=-1;i=b[i].ne)
        if(b[i].to!=fa[z]){
          fa[b[i].to]=z;
          v[z].push_back(b[i].to);
          q.push(b[i].to);
        }
        if(v[z].empty()) leaf[z]=1;               
    }
}
void dp(int x)
{ 
    if(has[x]!=-1&&no[x]!=-1) return ;
    if(leaf[x]){
        if(a[x]){ has[x]=1; no[x]=1; }
        else { has[x]=0; no[x]=1; }
        return ;
    }
    has[x]=no[x]=0;
    no[x]=1; 
    for(int i=0;i<v[x].size();i++){
        dp(v[x][i]);
        no[x]=(no[x]%mod*no[v[x][i]]%mod)%mod;
    }
    if(a[x]) has[x]=(no[x]%mod);
    else{
        LL op;
        for(int i=0;i<v[x].size();i++){
          op=1;
            for(int j=0;j<v[x].size();j++)
            if(j!=i) op=(op*no[v[x][j]])%mod;
          has[x]=(has[x]+(has[v[x][i]]*op)+mod)%mod;
        } 
        no[x]=(has[x]+no[x])%mod;
    }
    return ;
}
int main()
{
    memset(has,-1,sizeof(has));
    memset(no,-1,sizeof(no));
    memset(head,-1,sizeof(head));
    scanf("%d",&n);
    int x;
    for(int i=1;i<n;i++){
        scanf("%d",&x);
        add(i,x);add(x,i);
    }
    dfs(0);
    for(int i=0;i<n;i++)  scanf("%d",&a[i]);
    dp(0);
    printf("%lld",(has[0]+mod)%mod);
    return 0;
}


posted @ 2017-08-10 21:44  HunterxHunterl  阅读(203)  评论(0编辑  收藏  举报