Loj #2542. 「PKUWC2018」随机游走
Loj #2542. 「PKUWC2018」随机游走
题目描述
给定一棵 \(n\) 个结点的树,你从点 \(x\) 出发,每次等概率随机选择一条与所在点相邻的边走过去。
有 \(Q\) 次询问,每次询问给定一个集合 \(S\),求如果从 \(x\) 出发一直随机游走,直到点集 \(S\) 中所有点都至少经过一次的话,期望游走几步。
特别地,点 \(x\)(即起点)视为一开始就被经过了一次。
答案对 $998244353 $ 取模。
输入格式
第一行三个正整数 \(n,Q,x\)。
接下来 \(n-1\) 行,每行两个正整数 \((u,v)\) 描述一条树边。
接下来 \(Q\) 行,每行第一个数 \(k\) 表示集合大小,接下来 \(k\) 个互不相同的数表示集合 \(S\)。
输出格式
输出 \(Q\) 行,每行一个非负整数表示答案。
数据范围与提示
对于 \(20\%\) 的数据,有 \(1\leq n,Q\leq 5\)。
另有 \(10\%\) 的数据,满足给定的树是一条链。
另有 \(10\%\) 的数据,满足对于所有询问有 \(k=1\)。
另有 \(30\%\) 的数据,满足 \(1\leq n\leq 10 ,Q=1\)。
对于 \(100\%\) 的数据,有 \(1\leq n\leq 18\),\(1\leq Q\leq 5000\),\(1\leq k\leq n\)。
首先根据\(\min-\max\) 反演我们知道:
\[\max(S)=\sum_{T\subseteq S}(-1)^{|T|-1}\min(T)
\]
设\(f_{v,S}\)表示从\(v\)出发,经过\(S\)中至少一个节点的期望步数。
如果\(v\in S\),\(f_{v,S}=0\),否则:
\[f_v=1+\frac{1}{d_v}f_{fa}+\frac{1}{d_v}\sum f(u)\\
\]
然后这是颗树,我们可以将\(DP\)方程移项变成只与\(fa\)的\(f\)值个一个常数有关。
设:
\[f(v)=A_v*f_{fa}+B_v\\
\]
带回去化简:
\[f_v=1+\frac{1}{d_v}f_{fa}+\frac{1}{d_v}\sum (A_u*f_v+B_u)\\
(d_v-\sum A_u)*f_v=d_v+f_{fa}+\sum B_u\\
f_v=\frac{1}{d_v-\sum A_u}*f_{fa}+\frac{d_v+\sum B_u}{d_v-\sum A_u}
\]
得到:
\[A_v=\frac{1}{d_v-\sum A_u},B_v=\frac{d_v+\sum B_u}{d_v-\sum A_u}
\]
代码:
#include<bits/stdc++.h>
#define ll long long
#define N 19
using namespace std;
inline int Get() {int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}
const ll mod=998244353;
ll ksm(ll t,ll x) {
ll ans=1;
for(;x;x>>=1,t=t*t%mod)
if(x&1) ans=ans*t%mod;
return ans;
}
int n;
int X,m;
struct road {int to,nxt;}s[N<<1];
int h[N],cnt;
void add(int i,int j) {s[++cnt]=(road) {j,h[i]};h[i]=cnt;}
ll w[N];
int d[N];
ll A[N],B[N],f[N];
int tag[N];
void dfs(int v,int fa) {
A[v]=B[v]=0;
ll sumA=0,sumB=0;
for(int i=h[v];i;i=s[i].nxt) {
int to=s[i].to;
if(to==fa) continue ;
dfs(to,v);
sumA=(sumA+A[to])%mod;
sumB=(sumB+B[to])%mod;
}
if(tag[v]) A[v]=B[v]=0;
else A[v]=ksm(d[v]-sumA+mod,mod-2),B[v]=(d[v]+sumB)*ksm(d[v]-sumA+mod,mod-2)%mod;
}
void dfs2(int v,int fa) {
f[v]=(A[v]*f[fa]+B[v])%mod;
for(int i=h[v];i;i=s[i].nxt) {
int to=s[i].to;
if(to==fa) continue ;
dfs2(to,v);
}
}
int mn[1<<N];
int main() {
n=Get(),m=Get(),X=Get();
int a,b;
for(int i=1;i<n;i++) {
a=Get(),b=Get();
add(a,b),add(b,a);
d[a]++,d[b]++;
}
for(int S=1;S<1<<n;S++) {
for(int i=1;i<=n;i++) if(S>>i-1&1) tag[i]=1;
dfs(1,0),dfs2(1,0);
mn[S]=f[X];
for(int i=1;i<=n;i++) if(S>>i-1&1) tag[i]=0;
}
for(int S=1;S<1<<n;S++) {
int cnt=0;
for(int i=0;i<n;i++) cnt+=S>>i&1;
if(!(cnt&1)) mn[S]=(mod-mn[S])%mod;
}
for(int i=0;i<n;i++) {
for(int S=1;S<1<<n;S++) {
if(S>>i&1) mn[S]=(mn[S]+mn[S^(1<<i)]+mod)%mod;
}
}
while(m--) {
int k=Get();
int sta=0;
while(k--) sta|=1<<Get()-1;
cout<<mn[sta]<<"\n";
}
return 0;
}