Codeforces 494D Birthday 树形dp (看题解)

Birthday

没想到距离平方和也能在树上 DP 出来……关键在于恒等式 $(d+w)^2 = d^2 + 2wd + w^2$：每个结点同时维护子树内的点数、距离和、距离平方和，就能把整棵子树的距离统一平移 $w$ 后再合并。

知道了如何转移, 那么就很好写了。。。

#include<bits/stdc++.h>
#define LL long long
#define LD long double
#define ull unsigned long long
#define fi first
#define se second
#define mk make_pair
#define PLL pair<LL, LL>
#define PLI pair<LL, int>
#define PII pair<int, int>
#define SZ(x) ((int)x.size())
#define ALL(x) (x).begin(), (x).end()
#define fio ios::sync_with_stdio(false); cin.tie(0);

using namespace std;

constexpr int N = 100007;                        // capacity of every per-vertex array (1e5 + 7)
constexpr int inf = 0x3f3f3f3f;                  // contest-template sentinel; not used by this solution
constexpr long long INF = 0x3f3f3f3f3f3f3f3fLL;  // contest-template sentinel; not used by this solution
constexpr int mod = 1000000007;                  // modulus applied to every distance sum below
const double eps = 1e-8;                         // contest-template tolerance; not used by this solution
const double PI = std::acos(-1.0);               // contest-template constant; not used by this solution

// Modular helpers. add/sub assume both operands already lie in [0, mod),
// so a single conditional correction keeps the result in [0, mod).
template<class Acc, class Delta> inline void add(Acc& acc, Delta delta) {
    acc += delta;
    if (acc >= mod) {
        acc -= mod;
    }
}
template<class Acc, class Delta> inline void sub(Acc& acc, Delta delta) {
    acc -= delta;
    if (acc < 0) {
        acc += mod;
    }
}
// Replace `best` with `cand` when `cand` is strictly larger; report whether it changed.
template<class Target, class Source> inline bool chkmax(Target& best, Source cand) {
    if (!(best < cand)) return false;
    best = cand;
    return true;
}
// Replace `best` with `cand` when `cand` is strictly smaller; report whether it changed.
template<class Target, class Source> inline bool chkmin(Target& best, Source cand) {
    if (!(best > cand)) return false;
    best = cand;
    return true;
}

int n;            // number of vertices
int q;            // number of queries
int pa[N][20];    // pa[u][k]: the 2^k-th ancestor of u (0 once past the root)
int len[N][20];   // len[u][k]: total edge weight from u up to pa[u][k]; levels k >= 1 reduced modulo `mod`
int depth[N];     // depth of each vertex; the sentinel vertex 0 keeps depth 0, so the root gets 1
int allDis2[N];   // allDis2[u]: sum over all vertices x of dist(u, x)^2, modulo `mod`

std::vector<std::pair<int, int>> G[N];  // adjacency list: G[u] holds (edge weight, neighbour) pairs

// Aggregate over a set of vertices measured from one reference vertex:
//   cnt     - how many vertices are in the set,
//   sumDis  - sum of their distances to the reference (mod `mod`),
//   sumDis2 - sum of their squared distances to the reference (mod `mod`).
// dp[u] describes subtree(u) seen from u; dp2[u] describes everything outside subtree(u).
// INIT is a single vertex at distance 0.
struct dpNode {
    int cnt;
    int sumDis;
    int sumDis2;

    dpNode() : cnt(0), sumDis(0), sumDis2(0) {}
    dpNode(int cnt, int sumDis, int sumDis2) : cnt(cnt), sumDis(sumDis), sumDis2(sumDis2) {}

    // Debug helper: dump the three fields to stdout.
    void print() {
        printf("cnt: %d  sumDis: %d  sumDis2: %d\n", cnt, sumDis, sumDis2);
    }
} dp[N], dp2[N], INIT(1, 0, 0);

// Merge aggregate `b` into a copy of `a` after lengthening every distance in `b` by `w`.
// op > 0 adds the shifted `b`; any other op removes it (the exact mirror of adding).
// Shifting relies on:  sum(d + w)   = sumDis  + cnt * w
//                      sum((d+w)^2) = sumDis2 + 2w * sumDis + cnt * w^2
dpNode mergeTwo(dpNode a, dpNode b, int w, int op) {
    const long long linearShift = 1LL * b.cnt * w % mod;            // cnt * w
    const long long crossTerm   = 1LL * b.sumDis * 2 * w % mod;     // 2 * w * sumDis
    const long long squareShift = 1LL * w * w % mod * b.cnt % mod;  // cnt * w^2

    if (op <= 0) {
        a.cnt -= b.cnt;
        sub(a.sumDis, b.sumDis);
        sub(a.sumDis, linearShift);
        sub(a.sumDis2, b.sumDis2);
        sub(a.sumDis2, crossTerm);
        sub(a.sumDis2, squareShift);
        return a;
    }

    a.cnt += b.cnt;
    add(a.sumDis, b.sumDis);
    add(a.sumDis, linearShift);
    add(a.sumDis2, b.sumDis2);
    add(a.sumDis2, crossTerm);
    add(a.sumDis2, squareShift);
    return a;
}

void dfs(int u, int fa, int disTofa) {
    depth[u] = depth[fa] + 1;
    pa[u][0] = fa;
    len[u][0] = disTofa;
    dp[u].cnt = 1;
    for(int i = 1; i < 20; i++) {
        pa[u][i] = pa[pa[u][i - 1]][i - 1];
        len[u][i] = (len[u][i - 1] + len[pa[u][i - 1]][i - 1]) % mod;
    }
    for(auto &e : G[u]) {
        int v = e.se;
        if(v == fa) continue;
        dfs(v, u, e.fi);
        dp[u] = mergeTwo(dp[u], dp[v], e.fi, 1);
    }
}

// Return (dist(u, v) modulo `mod`, lca(u, v)) using the binary-lifting tables.
PII getLca(int u, int v) {
    // Let u be the deeper endpoint.
    if (depth[u] < depth[v]) std::swap(u, v);

    int pathLen = 0;
    // Raise u to v's depth by peeling off powers of two, largest first.
    for (int k = 19, gap = depth[u] - depth[v]; k >= 0; --k) {
        if (gap < (1 << k)) continue;
        gap -= 1 << k;
        add(pathLen, len[u][k]);
        u = pa[u][k];
    }
    if (u == v) return mk(pathLen, u);

    // Raise both while their ancestors still differ; they stop just below the LCA.
    for (int k = 19; k >= 0; --k) {
        if (pa[u][k] == pa[v][k]) continue;
        add(pathLen, len[u][k]);
        add(pathLen, len[v][k]);
        u = pa[u][k];
        v = pa[v][k];
    }

    // Final step: both edges into the common parent.
    add(pathLen, len[u][0]);
    add(pathLen, len[v][0]);
    return mk(pathLen, pa[u][0]);
}


void dfs2(int u, int fa, dpNode up) {
    dp2[u] = up; dp2[u].cnt--;
    allDis2[u] = (dp[u].sumDis2 + up.sumDis2) % mod;
    for(auto &e : G[u]) {
        if(e.se == fa) continue;
        up = mergeTwo(up, dp[e.se], e.fi, 1);
    }
    for(auto &e : G[u]) {
        if(e.se == fa) continue;
        up = mergeTwo(up, dp[e.se], e.fi, -1);
        dfs2(e.se, u, mergeTwo(INIT, up, e.fi, 1));
        up = mergeTwo(up, dp[e.se], e.fi, 1);
    }
}

int main() {
    scanf("%d", &n);
    for(int i = 1; i < n; i++) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        G[u].push_back(mk(w, v));
        G[v].push_back(mk(w, u));
    }
    dfs(1, 0, 0);
    dfs2(1, 0, INIT);
    
    scanf("%d", &q);
    while(q--) {
        int u, v;
        scanf("%d%d", &u, &v);
        PII ret = getLca(u, v);
        int lca = ret.se, dis = ret.fi;
        int ans = 0;

        if(lca != v) {
            int x = mergeTwo(INIT, dp[v], dis, 1).sumDis2;
            int z = allDis2[u];
            ans = ((2 * x - z) % mod + mod) % mod;
        } else {
            int y = mergeTwo(INIT, dp2[v], dis, 1).sumDis2;
            int z = allDis2[u];
            ans = ((z - 2 * y) % mod + mod) % mod;
        }
        printf("%d\n", ans);
    }
    return 0;
}

/*
*/

 

posted @ 2019-06-07 16:54  NotNight  阅读(153)  评论(0)  编辑  收藏  举报