noip模拟赛 洗衣
分析:好神的一道题啊.对每棵树建个图跑一下floyd可以有40分,想要打出正解就得对树有比较深的认识了.
每次新生成一棵树都是由两棵树i,j拼成的,答案为原来两棵树的答案和+i中每个点到j中每个点的距离和.显然这个距离和不能直接算,涉及到求整体的值,通常考虑每条边的贡献.设i,j两棵树的连接点为p,q,边长为l.l对答案的贡献是i的点数*j的点数*l.i中每个点要走到j里面,必须要先走到p上,i中的每个点对它走过的路径的贡献是j的点数*路径的边权,因为每条路径都被若干个点经过,所以算出i中所有点到p的距离和乘上j的点数就是i中的边对答案的贡献.j中所有边对答案的贡献也是这样算的.
那么怎么求树i中所有点到t的距离和呢?还是可以把树i拆分成两棵子树x,y.如果p在i的左子树中,那么就是连接点p到t的距离加上连接x,y的边的长度l的和乘上y中的点数加上y中所有点到q的距离和加上x中所有点到t的距离和.这实际上就相当于拆成了3部分:1是x里面的所有点到t去,2是y中的所有点到q去,3是y中所有点聚集到q后一起到t.如果p在i的右子树中,做法类似,递归地算下去.
那么怎么求p到t的距离呢?还是分两个点所在的树来讨论.如果p,t在同一子树,就在同一子树内递归算,如果在不同子树就先到连接点,然后再走过去,和之前的做法差不多.
因为求得的东西要经常被用到,所以需要记录状态的,可是m≤60,点的个数可能多达2^60个,这该怎么记录......其实需要记录的只有连接点的状态,但是连接点的编号可能很大,也会爆空间,一个很神的技巧就是在map中保存状态orz.开long long就能存下了.
丧心病狂的一道题,求两棵树任意两点之间的距离和的做法在树形dp中比较常见,需要记一下.黑科技map可以记录不多,但是表示起来很大的状态.
#include <map> #include <cstdio> #include <cstring> #include <iostream> #include <algorithm> using namespace std; typedef long long ll; const int mod = 1e9 + 7; struct node { ll p, p1, p2; node(){} node(ll a, ll b, ll c) { p = a; if (b < c) { p1 = b; p2 = c; } else { p1 = c; p2 = b; } } bool operator < (const node & a) const { if (p != a.p) return p < a.p; if (p1 != a.p1) return p1 < a.p1; return p2 < a.p2; } }; int n; ll id1[170], id2[170], num1[170], num2[170], l[170], sizee[170]; ll ans[170]; map <pair< ll, ll > , ll > m; map <node, ll> m2; ll solve2(ll p, ll p1, ll p2) { if (!p) return 0; if (p1 == p2) return 0; node temp = node(p, p1, p2); if (m2.count(temp)) return m2[temp]; if (p1 < sizee[id1[p]]) { if (p2 < sizee[id1[p]]) m2[temp] = solve2(id1[p], p1, p2); else m2[temp] = (solve2(id1[p], num1[p], p1) + solve2(id2[p], num2[p], p2 - sizee[id1[p]]) + l[p]) % mod; } else { if (p2 < sizee[id1[p]]) m2[temp] = (solve2(id1[p], num1[p], p2) + solve2(id2[p], num2[p], p1 - sizee[id1[p]]) + l[p]) % mod; else m2[temp] = solve2(id2[p], p1 - sizee[id1[p]], p2 - sizee[id1[p]]); } return m2[temp]; } ll solve(ll p, ll n) { if (p == 0) return 0; pair <ll, ll> t = make_pair(p, n); if (m.count(t)) return m[t]; if (n < sizee[id1[p]]) { ll temp1 = (solve2(id1[p], num1[p], n) + l[p]) % mod; ll temp2 = temp1 * (sizee[id2[p]] % mod) % mod; ll temp3 = (solve(id2[p], num2[p]) + solve(id1[p], n)) % mod; m[t] = (temp2 + temp3) % mod; } else { ll temp1 = (solve2(id2[p], num2[p], n - sizee[id1[p]]) + l[p]) % mod; ll temp2 = temp1 * (sizee[id1[p]] % mod) % mod; ll temp3 = (solve(id1[p], num1[p]) + solve(id2[p], n - sizee[id1[p]])) % mod; m[t] = (temp2 + temp3) % mod; } return m[t]; } int main() { scanf("%d", &n); for (int i = 1; i <= n; i++) scanf("%lld%lld%lld%lld%lld", &id1[i], &id2[i], &num1[i], &num2[i], &l[i]); sizee[0] = 1; for (int i = 1; i <= n; i++) sizee[i] = sizee[id1[i]] + sizee[id2[i]]; for (int i = 1; i <= n; i++) { ll temp1 = solve(id1[i], num1[i]) * (sizee[id2[i]] % mod) % mod; ll temp2 = (sizee[id1[i]] % mod )* (sizee[id2[i]] % mod) % mod * l[i] % mod; ll temp3 = solve(id2[i], num2[i]) * (sizee[id1[i]] % mod) % mod; ans[i] = (((((((temp1 + temp2) % mod) + temp3) % mod) + ans[id1[i]]) % mod) + ans[id2[i]]) % mod; } for (int i = 1; i <= n; i++) printf("%lld\n", ans[i]); return 0; }