Codeforces 494D Birthday 树形dp (看题解)
没想到平方和能在树上dp出来的。。。
知道了如何转移, 那么就很好写了: 把一棵子树整体沿边权 w 平移时, Σ(d+w)² = Σd² + 2w·Σd + cnt·w², 所以每个点只需维护 cnt、Σd、Σd² 三个量。
// Codeforces 494D "Birthday" -- tree DP with rerooting over sums of squared
// distances, all arithmetic modulo 1e9+7.
//
// For every vertex the program keeps three aggregates over a vertex set:
// the number of vertices, the sum of distances to them, and the sum of squared
// distances to them. Moving the reference point across an edge of weight w
// turns every distance d into d + w, and
//     sum (d + w)^2 = sum d^2 + 2*w*sum d + cnt*w^2,
// so the three aggregates are closed under that shift (see mergeTwo).
//
// Input (stdin):  n, then n-1 lines "u v w" (edges), then q, then q lines "u v".
// Output (stdout): one integer per query.

#include <cstdio>
#include <utility>
#include <vector>

namespace {

constexpr int kMaxN = 100000 + 7;  // capacity of the per-vertex arrays
constexpr int kLog = 20;           // binary-lifting levels
constexpr int kMod = 1000000007;

// Adjacency entry: neighbouring vertex and the weight of the connecting edge.
struct Edge {
    int to;
    int w;
};

int n, q;
int pa[kMaxN][kLog];   // pa[u][i]: 2^i-th ancestor of u (0 when above the root)
int len[kMaxN][kLog];  // len[u][i]: path length from u to pa[u][i], mod kMod
int depth[kMaxN];      // depth[root] == 1; depth[0] == 0 acts as a sentinel
int allDis2[kMaxN];    // sum of squared distances from u to every vertex, mod kMod
std::vector<Edge> G[kMaxN];

// acc <- (acc + delta) mod kMod; acc is expected to be in [0, kMod).
inline void addMod(int& acc, long long delta) {
    acc = static_cast<int>((acc + delta) % kMod);
}

// acc <- (acc - delta) mod kMod, normalised back into [0, kMod).
inline void subMod(int& acc, long long delta) {
    acc = static_cast<int>(((acc - delta) % kMod + kMod) % kMod);
}

// Aggregate over a set of vertices as seen from one reference vertex.
struct dpNode {
    int cnt;      // number of vertices in the set (plain count, not reduced)
    int sumDis;   // sum of distances to them, mod kMod
    int sumDis2;  // sum of squared distances to them, mod kMod

    dpNode() : cnt(0), sumDis(0), sumDis2(0) {}
    dpNode(int cnt, int sumDis, int sumDis2)
        : cnt(cnt), sumDis(sumDis), sumDis2(sumDis2) {}

    // Debug helper; not called by the solution itself.
    void print() const {
        std::printf("cnt: %d sumDis: %d sumDis2: %d\n", cnt, sumDis, sumDis2);
    }
};

dpNode dp[kMaxN];            // dp[u]: aggregate of u's subtree, measured from u
dpNode dp2[kMaxN];           // dp2[u]: aggregate of everything outside u's subtree, from u
const dpNode INIT(1, 0, 0);  // a lone reference vertex: one vertex at distance 0

// Returns `a` with the aggregate `b`, shifted by `w`, added (op > 0) or
// removed (op <= 0). "Shifted by w" means every distance recorded in `b` is
// lengthened by w before being combined, which is what crossing an edge of
// weight w does.
dpNode mergeTwo(dpNode a, const dpNode& b, int w, int op) {
    const long long shifted = 1LL * b.cnt * w % kMod;                   // cnt * w
    const long long cross = 1LL * b.sumDis * 2 % kMod * w % kMod;       // 2 * w * sum d
    const long long squares = 1LL * w * w % kMod * b.cnt % kMod;        // cnt * w^2

    if (op > 0) {
        a.cnt += b.cnt;
        addMod(a.sumDis, b.sumDis);
        addMod(a.sumDis, shifted);
        addMod(a.sumDis2, b.sumDis2);
        addMod(a.sumDis2, cross);
        addMod(a.sumDis2, squares);
    } else {
        a.cnt -= b.cnt;
        subMod(a.sumDis, b.sumDis);
        subMod(a.sumDis, shifted);
        subMod(a.sumDis2, b.sumDis2);
        subMod(a.sumDis2, cross);
        subMod(a.sumDis2, squares);
    }
    return a;
}

// First pass: fills depth, the binary-lifting tables and the subtree
// aggregates dp[]. `fa` is the parent of `u` (0 for the root) and `disTofa`
// the weight of the edge between them.
// NOTE: recursion depth equals the height of the tree.
void dfs(int u, int fa, int disTofa) {
    depth[u] = depth[fa] + 1;
    pa[u][0] = fa;
    len[u][0] = disTofa;
    dp[u].cnt = 1;

    for (int i = 1; i < kLog; ++i) {
        const int half = pa[u][i - 1];
        pa[u][i] = pa[half][i - 1];
        len[u][i] = static_cast<int>((0LL + len[u][i - 1] + len[half][i - 1]) % kMod);
    }

    for (const Edge& e : G[u]) {
        if (e.to == fa) continue;
        dfs(e.to, u, e.w);
        dp[u] = mergeTwo(dp[u], dp[e.to], e.w, 1);
    }
}

// Returns {distance(u, v) mod kMod, lca(u, v)} using the binary-lifting tables.
std::pair<int, int> getLca(int u, int v) {
    if (depth[u] < depth[v]) std::swap(u, v);

    const int diff = depth[u] - depth[v];
    int dist = 0;

    // Lift the deeper vertex to the depth of the other one.
    for (int i = kLog - 1; i >= 0; --i) {
        if ((diff >> i) & 1) {
            addMod(dist, len[u][i]);
            u = pa[u][i];
        }
    }
    if (u == v) return {dist, u};

    // Lift both while their ancestors differ; they end just below the LCA.
    for (int i = kLog - 1; i >= 0; --i) {
        if (pa[u][i] != pa[v][i]) {
            addMod(dist, len[u][i]);
            addMod(dist, len[v][i]);
            u = pa[u][i];
            v = pa[v][i];
        }
    }
    addMod(dist, len[u][0]);
    addMod(dist, len[v][0]);
    return {dist, pa[u][0]};
}

// Second pass (rerooting). `up` is the aggregate, measured from `u`, of `u`
// itself plus every vertex outside u's subtree. Fills dp2[] and allDis2[].
void dfs2(int u, int fa, dpNode up) {
    dp2[u] = up;
    dp2[u].cnt--;  // drop u itself so dp2[u] covers only the outside vertices
    allDis2[u] = static_cast<int>((0LL + dp[u].sumDis2 + up.sumDis2) % kMod);

    // Aggregate of the whole tree as seen from u.
    dpNode whole = up;
    for (const Edge& e : G[u]) {
        if (e.to == fa) continue;
        whole = mergeTwo(whole, dp[e.to], e.w, 1);
    }

    for (const Edge& e : G[u]) {
        if (e.to == fa) continue;
        // Everything except the child's subtree, still measured from u ...
        const dpNode withoutChild = mergeTwo(whole, dp[e.to], e.w, -1);
        // ... then re-based onto the child by crossing the edge.
        dfs2(e.to, u, mergeTwo(INIT, withoutChild, e.w, 1));
    }
}

}  // namespace

int main() {
    if (std::scanf("%d", &n) != 1) return 0;
    for (int i = 1; i < n; ++i) {
        int u, v, w;
        if (std::scanf("%d%d%d", &u, &v, &w) != 3) return 0;
        G[u].push_back(Edge{v, w});
        G[v].push_back(Edge{u, w});
    }

    dfs(1, 0, 0);
    dfs2(1, 0, INIT);

    if (std::scanf("%d", &q) != 1) return 0;
    while (q-- > 0) {
        int u, v;
        if (std::scanf("%d%d", &u, &v) != 2) return 0;

        const std::pair<int, int> res = getLca(u, v);
        const int dis = res.first;
        const int lca = res.second;
        const long long total = allDis2[u];

        long long ans;
        if (lca != v) {
            // u lies outside v's subtree: every path from u into the subtree
            // passes through v, so shift dp[v] by dist(u, v).
            const long long inside = mergeTwo(INIT, dp[v], dis, 1).sumDis2;
            ans = 2 * inside - total;
        } else {
            // u lies inside v's subtree (or u == v): every path from u to the
            // outside passes through v, so shift dp2[v] by dist(u, v).
            const long long outside = mergeTwo(INIT, dp2[v], dis, 1).sumDis2;
            ans = total - 2 * outside;
        }
        ans = (ans % kMod + kMod) % kMod;
        std::printf("%d\n", static_cast<int>(ans));
    }
    return 0;
}