[loj2542]「PKUWC2018」随机游走——min-max容斥+树上高消

题目大意:

给定一棵 n 个结点的树,你从点 x 出发,每次等概率随机选择一条与所在点相邻的边走过去。
有 Q 次询问,每次询问给定一个集合 S,求如果从 x 出发一直随机游走,直到点集 S 中所有点都至少经过一次的话,期望游走几步。
特别地,点 x(即起点)视为一开始就被经过了一次。
答案对 998244353 取模。

思路:

看到所有点都经过一次直接上min-max反演。
然后我们要求每个集合的min,即标记关键点之后求期望意义下第一次到关键点要走的步数,上dp的话发现转移式子有环,树上高斯消元即可。
为了方便我们可以把所有答案都处理出来,但是这样复杂度为\(3^n\),常数大可能通过不了,发现我们需要求的就是每个状态的子集和,上FWT和FMT都可以将时间复杂度优化到\(2^n\times n\)

/*=======================================
 * Author : ylsoi
 * Time : 2019.2.10
 * Problem : loj2542
 * E-mail : ylsoi@foxmail.com
 * ====================================*/
#include<bits/stdc++.h>

#define REP(i,a,b) for(register int i=a,i##_end_=b;i<=i##_end_;++i)
#define DREP(i,a,b) for(register int i=a,i##_end_=b;i>=i##_end_;--i)
#define debug(x) cout<<#x<<"="<<x<<" "
#define fi first
#define se second
#define mk make_pair
#define pb push_back
typedef long long ll;

using namespace std;

void File(){
	freopen("loj2542.in","r",stdin);
	freopen("loj2542.out","w",stdout);
}

template<typename T>void read(T &_){
	_=0; T f=1; char c=getchar();
	for(;!isdigit(c);c=getchar())if(c=='-')f=-1;
	for(;isdigit(c);c=getchar())_=(_<<1)+(_<<3)+(c^'0');
	_*=f;
}

const int maxn=18+5;
const int maxw=(1<<19)+10;
const int mod=998244353;
int n,q,rt,all,cnt[maxw];
vector<int>G[maxn];
ll d[maxn],a[maxn],b[maxn],mx[maxw];
bool c[maxn];

inline ll qpow(register ll x,register ll y){
	ll ret=1; x%=mod;
	while(y){
		if(y&1)ret=ret*x%mod;
		x=x*x%mod;
		y>>=1;
	}
	return ret;
}

inline ll inv(register ll x){return qpow(x,mod-2);}

inline void dfs(register int u,register int fh){
	ll sa=0,sb=0;
	REP(i,0,G[u].size()-1){
		int v=G[u][i];
		if(v==fh)continue;
		dfs(v,u);
		sa=(sa+a[v])%mod;
		sb=(sb+b[v])%mod;
	}
	if(c[u])a[u]=b[u]=0;
	else{
		a[u]=inv(d[u]-sa);
		b[u]=(sb+d[u])*a[u]%mod;
	}
}

int main(){
	File();
	read(n),read(q),read(rt);
	all=(1<<n)-1;
	int u,v;
	REP(i,1,n-1){
		read(u),read(v);
		G[u].pb(v),++d[u];
		G[v].pb(u),++d[v];
	}
	c[2]=1;
	dfs(rt,0);
	REP(S,1,all)cnt[S]=__builtin_popcount(S)%2 ? 1 : -1;
	REP(S,1,all){
		REP(i,1,n)if((1<<(i-1))&S)c[i]=1;
		else c[i]=0;
		dfs(rt,0);
		mx[S]=b[rt]*cnt[S];
	}
	for(int len=1;len<=all;len<<=1)
		for(int L=0;L<=all;L+=len<<1)
			REP(i,L,L+len-1)
				mx[i+len]=(mx[i+len]+mx[i])%mod;
	int S=0,sz,x;
	REP(i,1,q){
		S=0;
		read(sz);
		REP(j,1,sz)read(x),S^=1<<(x-1);
		printf("%lld\n",(mx[S]+mod)%mod);
	}
	return 0;
}

posted @ 2019-02-10 20:30  ylsoi  阅读(200)  评论(0编辑  收藏  举报