noip模拟赛 收集果子

分析:显然的,树形dp,状态也很好想到:f[i][j]表示以i为根的子树收集到j个果子的方案数.转移的话就相当于是背包问题,每个子节点可以选或不选.如果不选子节点k的话,那么以k为根的子树的边无论断不断都没关系,贡献就是f[i][j] * 2^(size[k]).如果选的话,枚举一下收集到多少个果子,对答案的贡献就是f[i][j - p] * f[k][p].基本的计数原理.

      不过这个转移是O(n^3)的,怎么优化呢?状态定义为这个样子是没法继续优化的,如果把状态的表示改成dfs到第i个点,收集到j个果子的方案数,就能够神奇地做到O(n^2)了.因为dfs是每次先向下递归,然后子节点向上回溯嘛,向下递归的时候就用父节点的状态去更新子节点的状态,向上回溯就用子节点的答案去更新父节点的答案.也就是说:向下走,更新状态;向上走,统计答案.

 60分暴力:

#include <cstdio>
#include <cmath>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;
const long long mod = 1e9 + 7;
typedef long long ll;

ll n, g[1010], k, q[1010], sizee[1010], a[1010], f[1010][1010], head[1010], to[2020], nextt[2020], tot = 1;

void add(ll x, ll y)
{
    to[tot] = y;
    nextt[tot] = head[x];
    head[x] = tot++;
}

void dfs(ll u, ll fa)
{
    f[u][a[u]] = 1;
    sizee[u] = 1;
    for (ll i = head[u]; i; i = nextt[i])
    {
        ll v = to[i];
        if (v == fa)
            continue;
        dfs(v, u);
        sizee[u] += sizee[v];
        for (ll j = 0; j <= k; j++)
        {
            g[j] = q[sizee[v] - 1] * f[u][j] % mod;
            for (ll kk = 0; kk <= j; kk++)
            {
                g[j] += f[v][kk] * f[u][j - kk] % mod;
                g[j] %= mod;
            }
        }
        for (ll j = 0; j <= k; j++)
            f[u][j] = g[j];
    }
}

int main()
{
    scanf("%lld%lld", &n, &k);
    for (ll i = 1; i <= n; i++)
        scanf("%lld", &a[i]);
    for (ll i = 1; i < n; i++)
    {
        ll x, y;
        scanf("%lld%lld", &x, &y);
        add(x, y);
        add(y, x);
    }
    q[0] = 1;
    q[1] = 2;
    for (ll i = 2; i <= n; i++)
        q[i] = q[i - 1] * 2 % mod;
    dfs(1, 0);
    printf("%lld\n", f[1][k]);

    return 0;
}

AC:

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

const int mod = 1e9 + 7;

using namespace std;
typedef long long ll;
ll n, k, a[1010],sizee[1010], q[1010],f[1010][1010], head[1010], to[2020], nextt[2020], tot = 1;

void add(int x, int y)
{
    to[tot] = y;
    nextt[tot] = head[x];
    head[x] = tot++;
}

void dfs(int u, int fa)
{
    sizee[u] = 1;
    for (int i = head[u]; i; i = nextt[i])
    {
        int v = to[i];
        if (v == fa)
            continue;
        for (int j = 0; j + a[v] <= n; j++)
            f[v][j + a[v]] = f[u][j];
        dfs(v, u);
        sizee[u] += sizee[v];
        for (int j = 0; j <= n; j++)
            f[u][j] = (q[sizee[v] - 1] * f[u][j] % mod + f[v][j]) % mod;
    }
}

int main()
{
    scanf("%lld%lld", &n, &k);
    for (int i = 1; i <= n; i++)
        scanf("%lld", &a[i]);
    for (int i = 1; i < n; i++)
    {
        ll x, y;
        scanf("%lld%lld", &x, &y);
        add(x, y);
        add(y, x);
    }
    f[1][a[1]] = 1;
    q[0] = 1;
    for (int i = 1; i <= n; i++)
        q[i] = q[i - 1] * 2 % mod;
    dfs(1, 0);
    printf("%lld\n", f[1][k]);

    return 0;
}

 

posted @ 2017-11-03 18:03  zbtrs  阅读(368)  评论(0编辑  收藏  举报