hdu5593--ZYB's Tree(树形dp)

 

问题描述
ZYB有一颗N个节点的树,现在他希望你对于每一个点,求出离每个点距离不超过KK的点的个数.

两个点(x,y)在树上的距离定义为两个点树上最短路径经过的边数,

为了节约读入和输出的时间,我们采用如下方式进行读入输出:

读入:读入两个数A,B,令fai​​为节点i的父亲,fa1​​=0;fai​​=(Ai+B)%(i1)+1,i[2,N] .

输出:输出时只需输出N个点的答案的xor和即可。
输入描述
第一行一个整数TT表示数据组数。

接下来每组数据:

 一行四个正整数N,K,A,B.

 最终数据中只有两组N100000。

1T5,5000001N500000,1K10,10000001A,B1000000
输出描述
T行每行一个整数表示答案.
输入样例
1
3 1 1 1
输出样例
3

 

re了好多好多次。注意数据范围A*i+B是会超int的!

距离一个点距离为k的值就用距离这个点距离为1的点更新。可能是该点的儿子,也可能是该点的父亲。

儿子节点直接算没有什么需要注意的地方。然后 与父亲结点距离为k-1的节点的数量减去该节点贡献的部分 就是该节点经过父亲结点符合要求结点的数量。

父亲结点有个坑。。。首先,需要先算父亲结点再算儿子节点,其次更新时要倒着更新。。。具体看代码吧。。。

#include <bits/stdc++.h>

using namespace std;

typedef unsigned int ut;
typedef long long ll;

const int N = 500005;
const int K = 12;

struct Edge {
    int to, next;
} edge[N];
int head[N];
int cnt_edge;
void add_edge(int u, int v)
{
    edge[cnt_edge].to = v;
    edge[cnt_edge].next = head[u];
    head[u] = cnt_edge++;
}

int dp[N][K];
int fa[N];

int n, k;

void dfs(int u)
{
    dp[u][0] = 1;
    for (int i = head[u]; i != -1; i = edge[i].next)
    {
        int v = edge[i].to;
        dfs(v);
        for (int j = 1; j <= k; ++j)
        {
            dp[u][j] += dp[v][j - 1];
        }
    }
}

void solve(int u)
{
    if (u != 1)
    {
        for (int j = k; j >= 2; --j)//这里需要注意 更新的方向!
        {
            dp[u][j] += dp[ fa[u] ][j - 1] - dp[u][j - 2];
        }
        dp[u][1]++;
    }

    for (int i = head[u]; i != -1; i = edge[i].next)
    {
        int v = edge[i].to;
        solve(v);
    }
}

int main()
{
    int t;
    scanf("%d", &t);
    while (t--)
    {
        int a, b;

        scanf("%d%d%d%d", &n, &k, &a, &b);
        cnt_edge = 0;
        memset(dp, 0, sizeof dp);
        memset(head, -1, sizeof head);
        for (int i = 2; i <= n; ++i)
        {
            int f = ((ll)a * i + b) % (i - 1) + 1;
            add_edge(f, i);
            fa[i] = f;
        }
        dfs(1);
        solve(1);
        int ans = 0;
        for (int i = 1; i <= n; ++i)
        {
            int tmp = 0;
            for (int j = 0; j <= k; ++j)
            {
                tmp += dp[i][j];
            }
            ans ^= tmp;
        }
        printf("%d\n", ans);
    }
    return 0;
}

  

posted @ 2015-12-05 23:31  我不吃饼干呀  阅读(276)  评论(0编辑  收藏  举报