luogu P5643 [PKUWC2018]随机游走
https://www.luogu.com.cn/problem/P5643
首先肯定是 $\min$-$\max$ 容斥。
即对于一个集合 $S$，求出从 $u$ 出发第一次走到集合中的点的期望步数 $f[u]$。
对于 $u\in S$，显然 $f[u]=0$；
否则 $f[u]=\frac{1}{d[u]}\left(f[fa]+\sum_{v\in son}f[v]\right)+1$，其中 $d[u]$ 是 $u$ 的度数。
高斯消元太慢了,考虑树形DP,
按照套路,将转移方程写成关于父亲的一次函数
$$f[u]=K[u]f[fa]+B[u]$$
将儿子的 $f[v]=K[v]f[u]+B[v]$ 代入，并记 $sumK[u]=\sum_{v\in son}K[v]$，$sumB[u]=\sum_{v\in son}B[v]$：
$$\begin{aligned}
f[u]&=\frac{1}{d[u]}\Big(f[fa]+\sum_{v\in son}\big(K[v]f[u]+B[v]\big)\Big)+1\\
&=\frac{1}{d[u]}\big(f[fa]+sumK[u]f[u]+sumB[u]\big)+1
\end{aligned}$$
$$d[u]f[u]=f[fa]+sumK[u]f[u]+sumB[u]+d[u]$$
$$(d[u]-sumK[u])f[u]=f[fa]+sumB[u]+d[u]$$
$$f[u]=\frac{f[fa]+sumB[u]+d[u]}{d[u]-sumK[u]}$$
$$K[u]=\frac{1}{d[u]-sumK[u]},\qquad B[u]=\frac{sumB[u]+d[u]}{d[u]-sumK[u]}$$
$K,B$ 显然可以跑一遍 DP 求出来。
那么 $E[\min(S)]=f[root]=B[root]$（根没有父亲，$f[fa]$ 一项不存在）。
然后再看 $\min$-$\max$ 容斥的式子：
$$E[\max(S)]=\sum_{\varnothing\ne T\subseteq S}(-1)^{|T|-1}E[\min(T)]$$
这个直接跑个高维前缀和（子集和变换，FMT）即可。
code:
#include<bits/stdc++.h>
// Capacity of every per-vertex array. Vertices are indexed from 1 in this
// file, so trees of up to N - 1 = 18 vertices fit.
const int N = 19;
// 64-bit integer type used for intermediate products in modular arithmetic.
typedef long long ll;
// Prime modulus in which all results are computed and printed.
const int mod = 998244353;
using namespace std;
// Size of the table indexed by vertex-subset bitmask: 2^N entries plus slack.
const int U = (1 << N) + 5;
// One directed arc of the tree, kept in a linked adjacency list.
struct edge {
    int v;    // vertex this arc points to
    int nxt;  // index in e[] of the next arc leaving the same vertex, -1 at the end
};
// Arc pool; each undirected edge occupies two slots (one per direction).
edge e[N << 1];
// p[u] is the index in e[] of the first arc leaving u, or -1 when u has none.
int p[N];
// Number of arc slots handed out so far.
int eid;
void init() {
memset(p, -1, sizeof p);
eid = 0;
}
void insert(int u, int v) {
e[eid].v = v;
e[eid].nxt = p[u];
p[u] = eid ++;
}
// Modular exponentiation by repeated squaring: returns x^y modulo `mod`,
// always as a value in [0, mod).
//   x - base; any 64-bit value, negative values included.
//   y - exponent; expected to be non-negative. y == 0 yields 1, and a
//       negative y is treated like 0 instead of looping forever.
ll qpow(ll x, ll y) {
    // Reduce the base first: this keeps x * x well inside 64 bits for every
    // input and stops a negative base from producing a negative remainder.
    x %= mod;
    if (x < 0) x += mod;
    ll ret = 1;
    for (; y > 0; y >>= 1, x = x * x % mod) {
        if (y & 1) ret = ret * x % mod;
    }
    return ret;
}
// State of the tree DP, rewritten for every target set:
// f[u] = K[u] * f[parent] + B[u] is the expected hitting time from u.
int B[N];    // constant term of the relation above
int K[N];    // coefficient of f[parent] in the relation above
int col[N];  // col[u] != 0 marks u as a member of the current target set
int d[N];    // degree of each vertex
int n;       // number of vertices
int q;       // number of queries
int rt;      // start vertex, used as the root of the DP
// Table indexed by vertex bitmask: holds the signed E[min] terms first and
// E[max] once the subset sums have been accumulated.
int f[U];
void dfs(int u, int fa) {
K[u] = B[u] = 0;
if(col[u]) return ;
for(int i = p[u]; i + 1; i = e[i].nxt) {
int v = e[i].v;
if(v == fa) continue;
dfs(v, u);
B[u] = (B[u] + B[v]) % mod;
K[u] = (K[u] + K[v]) % mod;
}
K[u] = qpow(d[u] - K[u] + mod, mod - 2);
B[u] = 1ll * (d[u] + B[u]) % mod * K[u] % mod;
}
int main() {
init();
scanf("%d%d%d", &n, &q, &rt);
for(int i = 1; i < n; i ++) {
int u, v;
scanf("%d%d", &u, &v);
d[u] ++, d[v] ++;
insert(u, v), insert(v, u);
}
for(int S = 0; S < (1 << n); S ++) {
int cnt = 0;
for(int i = 0; i < n; i ++) col[i + 1] = (S >> i) & 1, cnt += col[i + 1];
dfs(rt, rt);
f[S] = cnt & 1? B[rt] : (mod - B[rt]) % mod;
}
// for(int S = 0; S < (1 << n); S ++) printf("%d ", f[S]); printf("\n\n");
for(int i = 0; i < n; i ++)
for(int S = 0; S < (1 << n); S ++)
if((S >> i) & 1)
f[S] = (f[S] + f[S ^ (1 << i)]) % mod;
while(q --) {
int k, x, S = 0;
scanf("%d", &k);
while(k --) {
scanf("%d", &x);
S |= (1 << (x - 1));
}
printf("%d\n", f[S]);
}
return 0;
}