luogu P5643 [PKUWC2018]随机游走

https://www.luogu.com.cn/problem/P5643

首先肯定是 $\min$-$\max$ 容斥：走遍集合 $S$ 中所有点的期望步数 $E[\max(S)]$ 可以由各个非空子集 $T\subseteq S$ 的 $E[\min(T)]$（第一次到达 $T$ 中某个点的期望步数）算出。
即对每个集合 $S$，求出从 $u$ 出发、第一次到达 $S$ 中某个点的期望步数 $f[u]$。

对于 $u\in S$，显然 $f[u]=0$。

否则 $f[u]=\dfrac{1}{d[u]}\Bigl(f[fa]+\sum_{v\in son(u)} f[v]\Bigr)+1$，其中 $d[u]$ 是 $u$ 的度数，$fa$ 是 $u$ 的父亲（根没有 $f[fa]$ 这一项）。

高斯消元太慢了（$2^n$ 个集合各做一次 $O(n^3)$），考虑树形 DP。
按照套路，将 $f[u]$ 写成关于父亲 $f[fa]$ 的一次函数：
$$f[u]=K[u]\,f[fa]+B[u]$$
代入转移方程，并记 $sumK[u]=\sum_{v\in son(u)}K[v]$，$sumB[u]=\sum_{v\in son(u)}B[v]$：
$$f[u]=\frac{1}{d[u]}\Bigl(f[fa]+\sum_{v\in son(u)}\bigl(K[v]\,f[u]+B[v]\bigr)\Bigr)+1=\frac{1}{d[u]}\bigl(f[fa]+sumK[u]\,f[u]+sumB[u]\bigr)+1$$

两边乘以 $d[u]$：
$$d[u]\,f[u]=f[fa]+sumK[u]\,f[u]+sumB[u]+d[u]$$

移项：
$$(d[u]-sumK[u])\,f[u]=f[fa]+sumB[u]+d[u]$$
$$f[u]=\frac{f[fa]+sumB[u]+d[u]}{d[u]-sumK[u]}$$

对照 $f[u]=K[u]\,f[fa]+B[u]$ 得
$$K[u]=\frac{1}{d[u]-sumK[u]},\qquad B[u]=\frac{sumB[u]+d[u]}{d[u]-sumK[u]}$$

$K,B$ 显然可以自底向上跑一遍 DP 求出来（对 $u\in S$ 直接取 $K[u]=B[u]=0$，不再往下递归）。

那么 $E[\min(S)]=f[root]=B[root]$（根没有父亲，一次函数中的 $f[fa]$ 项不存在）。

然后再看 $\min$-$\max$ 容斥的式子：

$$E[\max(S)]=\sum_{\varnothing\neq T\subseteq S}(-1)^{|T|-1}E[\min(T)]$$
令 $g[T]=(-1)^{|T|-1}E[\min(T)]$（且 $g[\varnothing]=0$），则 $E[\max(S)]=\sum_{T\subseteq S}g[T]$。直接跑一遍高位前缀和（子集和 / FMT）即可，之后每次询问 $O(1)$ 回答。

code:

#include<bits/stdc++.h>
// Array bound for per-vertex tables; vertex ids 1..n are used as indices,
// so n must be at most N - 1.
constexpr int N = 19;
using ll = long long;
// Prime modulus; modular inverses are taken as qpow(x, mod - 2).
constexpr int mod = 998244353;
using namespace std;
// Size of the table indexed by vertex-subset bitmasks (2^N) plus slack.
const int U = (1 << N) + 5;
// One directed half of an undirected tree edge, kept in a static pool and
// chained into per-vertex singly linked adjacency lists.
struct edge {
    int v;    // vertex this half-edge points to
    int nxt;  // index of the next half-edge leaving the same vertex, or -1
};
edge e[N << 1];  // every undirected edge is stored twice (u->v and v->u)
// p[u]: index in e[] of the first half-edge leaving u (-1 when u has none).
// eid : next unused slot in the edge pool e[].
int p[N], eid;

// Empties the adjacency lists: every list head becomes -1 and the edge pool
// is rewound to slot 0.
void init() {
    eid = 0;
    for (int &head : p) head = -1;
}
void insert(int u, int v) {
    e[eid].v = v;
    e[eid].nxt = p[u];
    p[u] = eid ++;
}
// Binary exponentiation: returns x^y reduced modulo `mod`.
// Called with y = mod - 2 to obtain a modular inverse (Fermat).
// NOTE: x is not reduced up front, so x * x must fit in a long long; the only
// caller in this file passes values below 2 * mod.
ll qpow(ll x, ll y) {
    ll result = 1;
    while (y != 0) {
        if (y & 1) result = result * x % mod;
        x = x * x % mod;
        y >>= 1;
    }
    return result;
}
// Tree-DP coefficients for the current target set: f[u] = K[u] * f[parent] + B[u].
int B[N], K[N];
int col[N];     // col[u] == 1 iff vertex u is in the current target set S
int d[N];       // degree of each vertex
int n, q, rt;   // vertex count, query count, root / starting vertex
int f[U];       // per bitmask S: signed E[min(S)], then E[max(S)] after the subset-sum
// Tree DP for the target set currently marked in col[].
// Visits the subtree of u (stopping at target vertices) and fills K[], B[] so
// that, modulo `mod`, f[w] = K[w] * f[parent(w)] + B[w] for each visited w,
// where f is the expected number of steps until a target vertex is first hit.
// Target vertices get K = B = 0 and their subtrees are not entered.
//   u  : current vertex
//   fa : parent of u (the root is passed as its own parent)
void dfs(int u, int fa) {
    K[u] = 0;
    B[u] = 0;
    if (col[u]) return;  // u is a target: 0 further steps

    int sumK = 0;  // sum of K[v] over children v
    int sumB = 0;  // sum of B[v] over children v
    for (int i = p[u]; i != -1; i = e[i].nxt) {
        const int v = e[i].v;
        if (v == fa) continue;
        dfs(v, u);
        sumK = (sumK + K[v]) % mod;
        sumB = (sumB + B[v]) % mod;
    }

    // From d[u] f[u] = f[fa] + sum(K[v] f[u] + B[v]) + d[u]:
    //   K[u] = 1 / (d[u] - sumK),  B[u] = (d[u] + sumB) / (d[u] - sumK)
    const int denomInv = qpow(d[u] - sumK + mod, mod - 2);
    K[u] = denomInv;
    B[u] = 1ll * (d[u] + sumB) % mod * denomInv % mod;
}
// Reads the tree and the queries, precomputes E[max(S)] for every vertex set S
// via min-max inclusion-exclusion plus a subset-sum, then answers each query
// with a table lookup.
// Input : n q rt, then n-1 edges "u v", then q queries "k x_1 ... x_k".
// Output: one line per query, E[max(S)] modulo `mod`.
int main() {
    init();
    scanf("%d%d%d", &n, &q, &rt);
    for (int i = 1; i < n; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        d[u]++, d[v]++;
        insert(u, v), insert(v, u);
    }

    // f[S] = (-1)^{|S|-1} * E[min(S)] for every NON-EMPTY S.
    // The empty set is skipped deliberately. The walk never stops, so its
    // linear system is singular. The min-max formula only sums over non-empty
    // subsets, so f[0] must be 0. Running the DP on S = 0 yields 0 only because
    // qpow(0, mod - 2) happens to return 0.
    f[0] = 0;
    for (int S = 1; S < (1 << n); S++) {
        int cnt = 0;
        for (int i = 0; i < n; i++) {
            col[i + 1] = (S >> i) & 1;
            cnt += col[i + 1];
        }
        // rt is passed as its own parent; the tree has no self-loops, so no
        // real neighbour of rt is skipped.
        dfs(rt, rt);
        f[S] = (cnt & 1) ? B[rt] : (mod - B[rt]) % mod;
    }

    // Subset-sum (zeta transform): f[S] becomes sum over T subset of S of the
    // signed terms, i.e. E[max(S)].
    for (int i = 0; i < n; i++)
        for (int S = 0; S < (1 << n); S++)
            if ((S >> i) & 1)
                f[S] = (f[S] + f[S ^ (1 << i)]) % mod;

    while (q--) {
        int k, x, S = 0;
        scanf("%d", &k);
        while (k--) {
            scanf("%d", &x);
            S |= 1 << (x - 1);
        }
        printf("%d\n", f[S]);
    }
    return 0;
}

posted @ 2021-09-26 08:10  lahlah  阅读(36)  评论(0)  编辑  收藏  举报