P6803-[CEOI2020]星际迷航【博弈论,dp,矩阵乘法】

正题

题目链接:https://www.luogu.com.cn/problem/P6803


题目大意

给出一棵\(n\)个点的树,把它复制出\(D+1\)层,编号为\([0,D]\),然后每一层随机一个点向下一层随机一个点连边。

然后从第\(0\)层的\(1\)号点出发,两个人轮流操作走向一个之前没有走过的点,求有多少种连边方案使得先手必胜。

\(1\leq n\leq 10^5,1\leq D\leq 10^{18}\)


解题思路

我们先只考虑连到下一层的那个点是必胜还是必败的。

显然,连接的下一层的点如果是必胜的,那么局面不会有任何改变。而如果连接的是必败的点,那么原本必败的情况就会变成必胜的情况。

考虑对于每个点算出\(G_{x,0/1}\)表示从\(x\)出发的情况,如果连接的下层点必败,这一层有多少种连边情况会先手必胜/先手必败。

这个东西虽然比较麻烦,但是可以通过换根\(dp\)求出。

然后把所有点的求个和得到\(S_{0,0/1}\)表示下一层连接先手必败的点,这一层得到胜/负的方案。

同样的通过每一个点为根时树的胜负情况得到\(S_{1,0/1}\)表示下一层连接先手必胜的点(局面不会改变),这一层得到胜/负的方案。

这样做\(D\)层我们就可以直接矩阵乘法了。

时间复杂度:\(O(n+\log D)\)


code

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define ll long long
using namespace std;
const ll N=1e5+10,S=2,P=1e9+7;
struct Matrix{
	ll a[S][S];
}s,ans,c;
Matrix operator*(const Matrix &a,const Matrix &b){
	memset(c.a,0,sizeof(c.a));
	for(ll i=0;i<S;i++)
		for(ll j=0;j<S;j++)
			for(ll k=0;k<S;k++)
				(c.a[i][j]+=a.a[i][k]*b.a[k][j]%P)%=P;
	return c;
}
struct node{
	ll to,next;
}a[N<<1];
ll n,d,tot,ls[N],g[N][2];
bool f[N];vector<ll> q[N];
void addl(ll x,ll y){
	a[++tot].to=y;
	a[tot].next=ls[x];
	ls[x]=tot;return;
}
void dfs(ll x,ll fa){
	ll p=0;
	for(ll i=ls[x];i;i=a[i].next){
		ll y=a[i].to;
		if(y==fa)continue;
		dfs(y,x);f[x]|=!f[y];
		if(!f[y])p=(p?-1:y);
		g[x][1]+=g[y][0];
		g[x][0]+=g[y][1];
	}
	if(p==-1)g[x][1]+=g[x][0];
	else if(p>0)g[x][1]+=g[x][0]-g[p][1];
	if(p==-1)g[x][0]=0;
	else if(p>0)g[x][0]=g[p][1];
	g[x][1]++;
	return;
}
void solve(ll x,ll fa,ll p,ll s0,ll s1){
	if(!p)q[x].push_back(fa);
	for(ll i=ls[x];i;i=a[i].next){
		ll y=a[i].to;
		if(y==fa)continue;
		if(!f[y]&&q[x].size()<3)q[x].push_back(y);
		s0+=g[y][1];s1+=g[y][0];
	}
	if(!q[x].size())p=0;
	else if(q[x].size()==1)p=q[x][0];
	else p=-1;
	int pg0=g[x][0],pg1=g[x][1];
	g[x][0]=s0,g[x][1]=s1;
	if(p==-1)g[x][1]+=g[x][0],g[x][0]=0;
	else if(p>0)g[x][1]+=g[x][0]-g[p][1],g[x][0]=g[p][1];
	g[x][1]++;ll _f=(p?1:0);ans.a[0][_f]++;
	s.a[0][0]+=g[x][0];s.a[0][1]+=g[x][1];
	s.a[1][0]+=(!_f)*n;s.a[1][1]+=_f*n;
	for(ll i=ls[x];i;i=a[i].next){
		ll y=a[i].to;
		if(y==fa)continue;
		if(!q[x].size())p=0;
		else if(q[x].size()==1){
			if(q[x][0]==y)p=0;
			else p=q[x][0];
		}
		else if(q[x].size()==2){
			if(q[x][0]==y)p=q[x][1];
			else if(q[x][1]==y)p=q[x][0];
			else p=-1;
		}
		else p=-1;
		g[x][0]=s0-g[y][1];g[x][1]=s1-g[y][0];
		if(p==-1)g[x][1]+=g[x][0],g[x][0]=0;
		else if(p>0)g[x][1]+=g[x][0]-g[p][1],g[x][0]=g[p][1];
		g[x][1]++;solve(y,x,p?1:0,g[x][1],g[x][0]);
	}
	g[x][0]=pg0;g[x][1]=pg1;return;
}
signed main()
{
	scanf("%lld%lld",&n,&d);
	for(ll i=1,x,y;i<n;i++){
		scanf("%lld%lld",&x,&y);
//		x=i+1;y=(i+1)/2;
		addl(x,y);addl(y,x);
	}
	dfs(1,0);
	ll A=g[1][1];
	solve(1,0,1,0,0);
	s.a[0][0]%=P;s.a[0][1]%=P;
	s.a[1][0]%=P;s.a[1][1]%=P;
	d--;
	while(d){
		if(d&1)ans=ans*s;
		s=s*s;d>>=1;
	}
	ll answer=A*ans.a[0][0]%P;
	(answer+=f[1]*n*ans.a[0][1]%P)%=P;
	printf("%lld\n",answer);
	return 0;
}
posted @ 2022-06-08 20:12  QuantAsk  阅读(35)  评论(0编辑  收藏  举报