noip模拟赛 洗衣

 

 

分析:好神的一道题啊.对每棵树建个图跑一下floyd可以有40分,想要打出正解就得对树有比较深的认识了.

      每次新生成一棵树都是由两棵树i,j拼成的,答案为原来两棵树的答案和+i中每个点到j中每个点的距离和.显然这个距离和不能直接算,涉及到求整体的值,通常考虑每条边的贡献.设i,j两棵树的连接点为p,q,边长为l.l对答案的贡献是i的点数*j的点数*l.i中每个点要走到j里面,必须要先走到p上,i中的每个点对它走过的路径的贡献是j的点数*路径的边权,因为每条路径都被若干个点经过,所以算出i中所有点到p的距离和乘上j的点数就是i中的边对答案的贡献.j中所有边对答案的贡献也是这样算的.

      那么怎么求树i中所有点到t的距离和呢?还是可以把树i拆分成两棵子树x,y.如果p在i的左子树中,那么就是连接点p到t的距离加上连接x,y的边的长度l的和乘上y中的点数加上y中所有点到q的距离和加上x中所有点到t的距离和.这实际上就相当于拆成了3部分:1是x里面的所有点到t去,2是y中的所有点到q去,3是y中所有点聚集到q后一起到t.如果p在i的右子树中,做法类似,递归地算下去.

      那么怎么求p到t的距离呢?还是分两个点所在的树来讨论.如果p,t在同一子树,就在同一子树内递归算,如果在不同子树就先到连接点,然后再走过去,和之前的做法差不多.

      因为求得的东西要经常被用到,所以需要记录状态的,可是m≤60,点的个数可能多达2^60个,这该怎么记录......其实需要记录的只有连接点的状态,但是连接点的编号可能很大,也会爆空间,一个很神的技巧就是在map中保存状态orz.开long long就能存下了.

      丧心病狂的一道题,求两棵树任意两点之间的距离和的做法在树形dp中比较常见,需要记一下.黑科技map可以记录不多,但是表示起来很大的状态.

 

#include <map>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

typedef long long ll;

const int mod = 1e9 + 7;

struct node
{
    ll p, p1, p2;
    node(){}
    node(ll a, ll b, ll c)
    {
        p = a;
        if (b < c)
        {
            p1 = b;
            p2 = c;
        }
        else
        {
            p1 = c;
            p2 = b;
        }
    }
    bool operator < (const node & a) const {
        if (p != a.p)
            return p < a.p;
        if (p1 != a.p1)
            return p1 < a.p1;
        return p2 < a.p2;
    }
};

int n;
ll id1[170], id2[170], num1[170], num2[170], l[170], sizee[170];
ll ans[170];

map <pair< ll, ll > , ll > m;
map <node, ll> m2;

ll solve2(ll p, ll p1, ll p2)
{
    if (!p)
        return 0;
    if (p1 == p2)
        return 0;
    node temp = node(p, p1, p2);
    if (m2.count(temp))
        return m2[temp];
    if (p1 < sizee[id1[p]])
    {
        if (p2 < sizee[id1[p]])
            m2[temp] = solve2(id1[p], p1, p2);
        else
            m2[temp] = (solve2(id1[p], num1[p], p1) + solve2(id2[p], num2[p], p2 - sizee[id1[p]]) + l[p]) % mod;
    }
    else
    {
        if (p2 < sizee[id1[p]])
            m2[temp] = (solve2(id1[p], num1[p], p2) + solve2(id2[p], num2[p], p1 - sizee[id1[p]]) + l[p]) % mod;
        else
            m2[temp] = solve2(id2[p], p1 - sizee[id1[p]], p2 - sizee[id1[p]]);
    }
    return m2[temp];
}

ll solve(ll p, ll n)
{
    if (p == 0)
        return 0;
    pair <ll, ll> t = make_pair(p, n);
    if (m.count(t))
        return m[t];
    if (n < sizee[id1[p]])
    {
        ll temp1 = (solve2(id1[p], num1[p], n) + l[p]) % mod;
        ll temp2 = temp1 * (sizee[id2[p]] % mod) % mod;
        ll temp3 = (solve(id2[p], num2[p]) + solve(id1[p], n)) % mod;
        m[t] = (temp2 + temp3) % mod;
    }
    else
    {
        ll temp1 = (solve2(id2[p], num2[p], n - sizee[id1[p]]) + l[p]) % mod;
        ll temp2 = temp1 * (sizee[id1[p]] % mod) % mod;
        ll temp3 = (solve(id1[p], num1[p]) + solve(id2[p], n - sizee[id1[p]])) % mod;
        m[t] = (temp2 + temp3) % mod;
    }
    return m[t];
}

int main()
{
    scanf("%d", &n);
    for (int i = 1; i <= n; i++)
        scanf("%lld%lld%lld%lld%lld", &id1[i], &id2[i], &num1[i], &num2[i], &l[i]);
    sizee[0] = 1;
    for (int i = 1; i <= n; i++)
        sizee[i] = sizee[id1[i]] + sizee[id2[i]];
    for (int i = 1; i <= n; i++)
    {
        ll temp1 = solve(id1[i], num1[i]) * (sizee[id2[i]] % mod) % mod;
        ll temp2 = (sizee[id1[i]] % mod )* (sizee[id2[i]] % mod) % mod * l[i] % mod;
        ll temp3 = solve(id2[i], num2[i]) * (sizee[id1[i]] % mod) % mod;
        ans[i] = (((((((temp1 + temp2) % mod) + temp3) % mod) + ans[id1[i]]) % mod) + ans[id2[i]]) % mod;
    }
    for (int i = 1; i <= n; i++)
        printf("%lld\n", ans[i]);
return 0;
}

 

posted @ 2017-11-01 09:58  zbtrs  阅读(177)  评论(0编辑  收藏  举报