Codeforces 1032F Vasya and Maximum Matching dp

Vasya and Maximum Matching

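首先可以观察到：删边后得到的森林，其最大匹配唯一，当且仅当每个连通块要么是孤立点，要么存在完美匹配。（树的完美匹配若存在必唯一；反之，若一个点数 ≥ 2 的连通块没有完美匹配，取一个未匹配点 w，它的某个邻居 x 必然已被匹配，否则可以加入边 (w, x)；把 x 的匹配边换成 (w, x) 即得到另一个最大匹配。）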

dp[i][0]、dp[i][1]、dp[i][2] 分别表示：在以 i 为根的子树中删边，使子树内已经确定的连通块都合法（孤立点或有完美匹配）的方案数，其中

 0：删掉 i 与父亲的边，且 i 所在连通块已合法（i 为孤立点，或已与某个儿子匹配）；1：保留 i 与父亲的边，且 i 已与某个儿子匹配；2：保留 i 与父亲的边，且 i 尚未匹配，必须与父亲匹配。

#include<bits/stdc++.h>
// Short type aliases used throughout the solution. `using` aliases are scoped and
// type-checked, unlike the textual macros they replace.
using LL = long long;
using LD = long double;
using ull = unsigned long long;
using PLL = std::pair<LL, LL>;
using PLI = std::pair<LL, int>;
using PII = std::pair<int, int>;
// Member-access / construction shorthands; these can only be expressed as macros.
#define fi first
#define se second
#define mk make_pair
// Container size as int. The argument is parenthesized so that expressions such as
// SZ(*p) expand correctly; the previous `(int)x.size()` applied `.size()` to only
// the last operand of the argument expression.
#define SZ(x) ((int)(x).size())
#define ALL(x) (x).begin(), (x).end()
#define fio ios::sync_with_stdio(false); cin.tie(0);

using namespace std;

// Global constants for the solution.
constexpr int N = 300007;                  // bound of the per-vertex arrays (3e5 + 7)
constexpr int inf = 0x3f3f3f3f;            // large int sentinel
constexpr LL INF = 0x3f3f3f3f3f3f3f3f;     // large long long sentinel
constexpr int mod = 998244353;             // modulus for all modular arithmetic
constexpr double eps = 1e-8;               // floating-point tolerance
const double PI = acos(-1);                // acos is not constexpr, so this stays const

// In-place modular add / subtract. The single correction step keeps `a` in [0, mod)
// as long as both `a` and `b` already lie in [0, mod).
template<class T, class S> inline void add(T& a, S b) {
    a += b;
    if (a >= mod) {
        a -= mod;
    }
}
template<class T, class S> inline void sub(T& a, S b) {
    a -= b;
    if (a < 0) {
        a += mod;
    }
}
// chkmax / chkmin: overwrite `a` with `b` when `b` is strictly larger / smaller,
// and report whether `a` was changed.
template<class T, class S> inline bool chkmax(T& a, S b) {
    if (!(a < b)) {
        return false;
    }
    a = b;
    return true;
}
template<class T, class S> inline bool chkmin(T& a, S b) {
    if (!(a > b)) {
        return false;
    }
    a = b;
    return true;
}

int n = 0;           // vertex count, read in main()
vector<int> G[N];    // adjacency lists indexed by vertex; go(1, 0) later removes each parent entry

// Modular exponentiation: returns a^b mod `mod` by binary exponentiation.
// The base is reduced into [0, mod) first, so any `a` (negative or >= mod) is
// accepted and every intermediate product stays below mod^2, which fits in LL.
// Exponents b <= 0 return 1; negative exponents (modular inverses) are not supported.
LL power(LL a, LL b) {
    a %= mod;
    if (a < 0) a += mod;  // % keeps the sign of the dividend
    LL ans = 1;
    for (; b > 0; b >>= 1) {
        if (b & 1) ans = ans * a % mod;
        a = a * a % mod;
    }
    return ans;
}

// dp[x][s] counts (mod `mod`) the valid edge-deletion choices inside the subtree of x:
//   s = 0: the edge to x's parent is deleted and x's component is finished;
//   s = 1: the edge to the parent is kept and x is already matched below;
//   s = 2: the edge to the parent is kept and x is still waiting to be matched with it.
LL dp[N][3] = {};

// Orients the tree reachable from `u`. It removes each vertex's parent from that
// vertex's adjacency list, so that afterwards G[x] holds only the children of x.
// `fa` is u's parent. main passes 0 for the root, presumably because vertices
// are numbered from 1, so 0 matches no neighbour.
//
// This uses an explicit work list instead of recursion. The recursive form used one
// call frame per tree level, so a path-shaped tree of ~3e5 vertices could overflow
// the call stack. Removing a vertex's parent entry only touches that vertex's own
// list, so processing order does not change the final adjacency lists. The removal
// (swap with last, pop) is the same as before.
void go(int u, int fa) {
    std::vector<std::pair<int, int>> todo;  // (vertex, its parent) still to be processed
    todo.push_back(std::make_pair(u, fa));
    while (!todo.empty()) {
        const int x = todo.back().first;
        const int parent = todo.back().second;
        todo.pop_back();
        int pos = -1;
        for (int i = 0; i < static_cast<int>(G[x].size()); i++) {
            if (G[x][i] == parent) {
                pos = i;
                continue;
            }
            todo.push_back(std::make_pair(G[x][i], x));
        }
        if (pos != -1) {
            std::swap(G[x][pos], G[x].back());
            G[x].pop_back();
        }
    }
}

void dfs(int u) {
    dp[u][0] = 1;
    dp[u][1] = 0;
    dp[u][2] = 1;
    if(!SZ(G[u])) return;
    for(auto& v : G[u]) dfs(v);
    int cnts = SZ(G[u]);
    vector<LL> prefix[3];
    for(int i = 0; i < 3; i++) {
        prefix[i].resize(cnts);
        for(int j = 0; j < cnts; j++) {
            int v = G[u][j];
            if(!j) prefix[i][j] = dp[v][i];
            else prefix[i][j] = prefix[i][j - 1] * dp[v][i] % mod;
        }
    }
    vector<LL> prefix01(cnts);
    vector<LL> suffix01(cnts);
    for(int i = 0; i < cnts; i++) {
        int v = G[u][i];
        if(!i) prefix01[i] = (dp[v][0] + dp[v][1]) % mod;
        else prefix01[i] = prefix01[i - 1] * (dp[v][0] + dp[v][1]) % mod;
    }
    for(int i = cnts - 1; i >= 0; i--) {
        int v = G[u][i];
        if(i == cnts - 1) suffix01[i] = (dp[v][0] + dp[v][1]) % mod;
        else suffix01[i] = suffix01[i + 1] * (dp[v][0] + dp[v][1]) % mod;
    }
// 0: 不向上连边完成  1:向上连边完成  2:向上连边未完成

    dp[u][0] = prefix[0][cnts - 1];
    for(int i = 0; i < cnts; i++) {
        int v = G[u][i];
        LL tmp = dp[v][2];
        if(i - 1 >= 0) tmp = tmp * prefix01[i - 1] % mod;
        if(i + 1 < cnts) tmp = tmp * suffix01[i + 1] % mod;
        add(dp[u][0], tmp);
        add(dp[u][1], tmp);
    }
    dp[u][2] = prefix01[cnts - 1];
}

// Reads a tree: n, then n - 1 edges with endpoints in [1, n]. It roots the tree at
// vertex 1 and prints dp[1][0]. Per the problem this is the number of edge-deletion
// choices whose resulting forest has a unique maximum matching, modulo `mod`.
// Malformed or out-of-range input exits with status 1 rather than indexing G with
// uninitialized or out-of-bounds vertex numbers.
int main() {
    // Number of vertex slots in the global adjacency table; vertex ids must be below it.
    const int capacity = static_cast<int>(sizeof(G) / sizeof(G[0]));
    if (scanf("%d", &n) != 1 || n < 1 || n >= capacity) return 1;
    for (int i = 1; i < n; i++) {
        int u = 0;
        int v = 0;
        if (scanf("%d%d", &u, &v) != 2) return 1;
        if (u < 1 || u > n || v < 1 || v > n) return 1;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    go(1, 0);
    dfs(1);
    printf("%lld\n", dp[1][0]);
    return 0;
}

/*
*/

 

posted @ 2019-05-06 17:15  NotNight  阅读(167)  评论(0)  编辑  收藏  举报