HDU 5647 DZY Loves Connecting 树形dp

DZY Loves Connecting

题目连接:

http://acm.hdu.edu.cn/showproblem.php?pid=5647

Description

DZY has an unrooted tree consisting of n nodes labeled from 1 to n.

DZY likes connected sets on the tree. A connected set S is a set of nodes, such that every two nodes u,v in S can be connected by a path on the tree, and the path should only contain nodes from S. Obviously, a set consisting of a single node is also considered a connected set.

The size of a connected set is defined by the number of nodes which it contains. DZY wants to know the sum of the sizes of all the connected sets. Can you help him count it?

The answer may be large. Please output modulo 109+7.

Input

First line contains t denoting the number of testcases.
t testcases follow. In each testcase, first line contains n. In lines 2∼n, ith line contains pi, meaning there is an edge between node i and node pi. (1≤pi≤i−1,2≤i≤n)

(n≥1, sum of n in all testcases does not exceed 200000)

Output

Output one line for each testcase, modulo 109+7.

Sample Input

2
1
5
1
2
2
3

Sample Output

1
42

Hint

题意

给你一颗树,然后求这棵树的所有连通块大小的和是多少。

题解:

树形dp

dp[i][0]表示以i为根的连通块个数,显然dp[i][0]=PI(dp[v][0]+1),v是i的孩子。

这个用一个排列组合很容易看出来。

他的ans实质上是在问,每个点出现在多少个集合的和。

那么dp[i][1]表示以第i个点为根的所有连通块的点一共出现了多少次。

现在考虑到v孩子,那么v孩子的父亲x的贡献就增加了:

(dp[v][0]+1)*dp[x][1]+dp[x][0]*dp[v][1]

(dp[v][0]+1)*dp[x][1] 表示 从v节点出来的那些节点新增加的贡献。

dp[x][0]*dp[v][1] 表示 老的贡献又增加了多少。

然后就完了,作为一个树形dp智障……

代码

#pragma comment(linker, "/STACK:102400000,102400000")
#include<bits/stdc++.h>
using namespace std;
const int maxn = 2e5+7;
const int mod = 1e9+7;
vector<int>E[maxn];
long long dp[maxn][2];
int n;
void init()
{
    for(int i=0;i<=n;i++)E[i].clear();
}
void dfs(int x,int fa)
{
    dp[x][0]=dp[x][1]=1;
    for(int i=0;i<E[x].size();i++)
    {
        int v = E[x][i];
        if(v==fa)continue;
        dfs(v,x);
        dp[x][1]=(dp[x][1]*(dp[v][0]+1)+dp[x][0]*dp[v][1])%mod;
        dp[x][0]=(dp[x][0]*(dp[v][0]+1))%mod;
    }
}
void solve()
{
    scanf("%d",&n);
    init();
    for(int i=2;i<=n;i++)
    {
        int x;scanf("%d",&x);
        E[x].push_back(i);
        E[i].push_back(x);
    }
    dfs(1,0);
    long long ans = 0;
    for(int i=1;i<=n;i++)
        ans = (ans+dp[i][1])%mod;
    cout<<ans<<endl;
}
int main()
{
    int t;scanf("%d",&t);
    while(t--)solve();
}
posted @ 2016-03-23 12:43  qscqesze  阅读(361)  评论(0编辑  收藏  举报