Codeforces 1111E Tree 虚树 + dp

直接把 r 加进去建虚树, 考虑虚树上的dp, 我们考虑虚树的dfs序的顺序dp过去。

dp[ i ][ j ]  表示到 i 这个点为止, 分成 j 组有多少种合法方案。 

dp[ i ][ j ] = dp[ i - 1 ][ j ] * (j - have[ i ])  + dp[ i - 1 ][ j - 1 ], have[ i ] 表示 i 的祖先中有多少个在a中出现。

 

#include<bits/stdc++.h>
using namespace std;

const int N = (int)1e5 + 7;
const int mod = (int)1e9 + 7;
const int LOG = 17;

int n, q, k, m, r, cas, a[N];
int depth[N], pa[N][LOG];
int in[N], ot[N], dfs_clock;
int col[N], dp[301];
vector<int> G[N];

void dfs(int u, int fa) {
    in[u] = ++dfs_clock;
    depth[u] = depth[fa] + 1;
    pa[u][0] = fa;
    for(int i = 1; i < LOG; i++) {
        pa[u][i] = pa[pa[u][i - 1]][i - 1];
    }
    for(auto &v : G[u]) {
        if(v == fa) continue;
        dfs(v, u);
    }
    ot[u] = dfs_clock;
}

inline int getLca(int u, int v) {
    if(depth[u] < depth[v]) swap(u, v);
    int d = depth[u] - depth[v];
    for(int i = LOG - 1; i >= 0; i--) {
        if(d >> i & 1) {
            u = pa[u][i];
        }
    }
    if(u == v) return u;
    for(int i = LOG - 1; i >= 0; i--) {
        if(pa[u][i] != pa[v][i]) {
            u = pa[u][i];
            v = pa[v][i];
        }
    }
    return pa[u][0];
}


void go(int u, int fa, int have) {
    if(col[u] == cas) {
        for(int i = m; i >= 0; i--) {
            if(i < have + 1) dp[i] = 0;
            else {
                dp[i] = 1LL * dp[i] * (i - have) % mod + dp[i - 1];
                if(dp[i] >= mod) dp[i] -= mod;
            }
        }
    }
    for(auto &v : G[u]) {
        if(v == fa) continue;
        go(v, u, have + (col[u] == cas));
    }
}

int main() {
    scanf("%d%d", &n, &q);
    for(int i = 1; i < n; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        G[u].push_back(v);
        G[v].push_back(u);
    }
    dfs(1, 0);
    for(cas = 1; cas <= q; cas++) {
        vector<int> P;
        scanf("%d%d%d", &k, &m, &r);
        for(int i = 1; i <= k; i++) scanf("%d", &a[i]), col[a[i]] = cas;
        a[++k] = r;
        for(int i = 1; i <= k; i++) P.push_back(a[i]);
        sort(a + 1, a + 1 + k, [&](int x, int y) {return in[x] < in[y];});
        for(int i = 1; i < k; i++) P.push_back(getLca(a[i], a[i + 1]));
        sort(P.begin(), P.end());
        P.erase(unique(P.begin(), P.end()), P.end());
        sort(P.begin(), P.end(), [&](int x, int y) {return in[x] < in[y];});
        for(auto &t : P) G[t].clear();
        vector<int> S;
        for(auto &t : P) {
            while(S.size() && ot[S.back()] < in[t]) S.pop_back();
            if(S.size()) {
                G[S.back()].push_back(t);
                G[t].push_back(S.back());
            }
            S.push_back(t);
        }
        for(int i = 0; i <= m; i++) dp[i] = (i == 0);
        go(r, 0, 0);
        int ans = 0;
        for(int i = 1; i <= m; i++) {
            ans += dp[i];
            if(ans >= mod) ans -= mod;
        }
        printf("%d\n", ans);
    }
    return 0;
}

/**
**/

 

posted @ 2019-11-06 00:18  NotNight  阅读(253)  评论(0编辑  收藏  举报