Evensgn 剪树枝
题目描述
繁华中学有一棵苹果树。苹果树有 n 个节点(也就是苹果),n − 1 条边(也就
是树枝)。调皮的 Evensgn 爬到苹果树上。他发现这棵苹果树上的苹果有两种:一
种是黑苹果,一种是红苹果。Evensgn 想要剪掉 k 条树枝,将整棵树分成 k + 1 个
部分。他想要保证每个部分里面有且仅有一个黑苹果。请问他一共有多少种剪树枝
的方案?
输入
第一行一个数字 n,表示苹果树的节点(苹果)个数。
第二行一共 n − 1 个数字 p0, p1, p2, p3, ..., pn−2,pi 表示第 i + 1 个节点和 pi 节
点之间有一条边。注意,点的编号是 0 到 n − 1。
第三行一共 n 个数字 x0, x1, x2, x3, ..., xn−1。如果 xi 是 1,表示 i 号节点是黑
苹果;如果 xi 是 0,表示 i 号节点是红苹果。
输出
输出一个数字,表示总方案数。答案对 109 + 7 取模。
样例输入
样例输入 260 1 1 0 41 1 0 0 1 0样例输入 3100 1 2 1 4 4 4 0 80 0 0 1 0 1 1 0 0 1
样例输出
样例输出 12样例输出 21样例输出 327
树归
has[x] x与他父亲相连的连通块里有黑苹果的方案数
no[x] x与他父亲的联通快里没有黑苹果的方案数
初始化叶子节点 黑苹果 has[x]=1; no[x]=1;
红苹果 has[x]=0;no[x]=1;
黑苹果 no[x]=no[son[i][1]]*no[son[i][2]]* * * * 假设x与父亲相连的边被删掉
has[x]=no[x] x与父亲相连的边不被删,且保证连通块内只有一个黑苹果
红苹果 has[x]= ∑ has[son[i]]*no[son[j]]
no[x]=no[son[i][1]]*no[son[i][2]]* * * * *
no[x]=no[x]+has[x] x与其父亲的连边可以被删掉
#include<iostream> #include<cstdio> #include<cstring> #include<vector> #include<queue> #define maxn 100005 #define mod 1000000007 #define LL long long using namespace std; int n; struct edge { int to,ne; }b[maxn*2]; int k=0,head[maxn],a[maxn],fa[maxn]; bool leaf[maxn]; LL has[maxn],no[maxn]; vector <int > v[maxn]; void add(int u,int v) { k++; b[k].to=v;b[k].ne=head[u];head[u]=k; } queue<int > q; void dfs(int x) { q.push(x); while(!q.empty()) { int z=q.front();q.pop(); for(int i=head[z];i!=-1;i=b[i].ne) if(b[i].to!=fa[z]){ fa[b[i].to]=z; v[z].push_back(b[i].to); q.push(b[i].to); } if(v[z].empty()) leaf[z]=1; } } void dp(int x) { if(has[x]!=-1&&no[x]!=-1) return ; if(leaf[x]){ if(a[x]){ has[x]=1; no[x]=1; } else { has[x]=0; no[x]=1; } return ; } has[x]=no[x]=0; no[x]=1; for(int i=0;i<v[x].size();i++){ dp(v[x][i]); no[x]=(no[x]%mod*no[v[x][i]]%mod)%mod; } if(a[x]) has[x]=(no[x]%mod); else{ LL op; for(int i=0;i<v[x].size();i++){ op=1; for(int j=0;j<v[x].size();j++) if(j!=i) op=(op*no[v[x][j]])%mod; has[x]=(has[x]+(has[v[x][i]]*op)+mod)%mod; } no[x]=(has[x]+no[x])%mod; } return ; } int main() { memset(has,-1,sizeof(has)); memset(no,-1,sizeof(no)); memset(head,-1,sizeof(head)); scanf("%d",&n); int x; for(int i=1;i<n;i++){ scanf("%d",&x); add(i,x);add(x,i); } dfs(0); for(int i=0;i<n;i++) scanf("%d",&a[i]); dp(0); printf("%lld",(has[0]+mod)%mod); return 0; }