[ZJOI2019] Minimax搜索

一、题目

点此看题

二、解法

\(\tt md\) 这题真的把我心态整炸了,真的太神了,理解都搞了整整一个晚上。

注意本题只需要改变根节点的值,我们可以预处理出 \(dp[u]\) 表示 \(u\) 节点最初的权值,然后设 \(W=dp[1]\),考虑如果 \(W\)\(S\) 中那么代价一定是 \(1\),这种情况是平凡的。

第一个转化是我们本来想求 \(\max_{i\in S}|i-w_i|=k\) 的方案数,但是我们可以差分一下求出 \(\max_{i\in S}|i-w_i|\leq k\) 的方案数,也就是 \(\forall |i-w_i|\leq k\),这样就变成了关于每个点的限制。

考虑 \(dp[u]=W\)\(u\) 一定构成原树上的一条链,那么我们只要改变这条链上任意一个节点的值就可以改变根,那么我们可以把这条链断开,对于每个连通块单独讨论,然后用乘法原理合并即可。

继续考虑某个点 \(x\)\(W\) 链对应的根是 \(rt\),如果 \(rt\) 的深度是奇数,那么我们只有 \(x<W\) 把变成 \(x>W\) 才是可能有用的,\(x>W\) 不需要变化;如果 \(rt\) 的深度是偶数,那么我们只有 \(x>W\) 把变成 \(x<W\) 才是可能有用,\(x<W\) 不需要变化(可以把 \(<W\) 看成 \(0\)\(>W\) 看成 \(1\) 来理解这个结论)。

那么我们枚举 \(k\) 并且知道 \(S\) 之后可以贪心地确定每个点的取值。现在回到计数问题上来,考虑 \(rt\) 是奇数的情况,我们设 \(f[u]\) 表示只能使得 \(dp[u]<W\)\(S\) 数量,设 \(cnt[u]\) 表示 \(u\) 节点以内的 \(2\) 的叶子个数次方(表示总情况数),那么 \(cnt[u]-f[u]\)可能使得 \(dp[u]>W\)\(S\) 数量。

转移是容易写出的,我们根据当前点的深度奇偶性来讨论即可:

\[f[u]=\begin{cases}\prod_{v\in son_u} f[v]&dep[u]\bmod2=1\\ cnt[u]-\prod_{v\in son_u}(cnt[v]-f[v])& dep[u]\bmod 2=0 \end{cases} \]

我们同时也可以写出答案的表达式,设 \(sum\) 表示 \(S\) 的总数,可以用总数减去不合法的方案得出:

\[ans=sum-\prod_{dp[u]=W} ([dep[u]\%2=1]\cdot f[u]+[dep[u]\%2=0]\cdot (cnt[u]-f[u])) \]

注意如果 \(rt\) 是偶数那么 \(f[u]\) 的定义是可能使得 \(dp[u]>W\)\(S\) 数量,但是转移方式是完全一致的,只有初始值的设置不一样,所以这里就混为一谈了。

我们可以对转移有一个简化,考虑如果 \(dep[u]\) 是奇数那么 \(g[u]=cnt[u]-f[u]\);否则 \(g[u]=f[u]\),那么我们用 \(g[u]\) 来转移就会得到很简洁的形式:

\[g[u]=cnt[u]-\prod_{v\in son_u}g[v] \]

\[ans=sum-\prod_{dp[u]=W} (cnt[u]-g[u]) \]


上面是单个 \(k\) 的方法,接下来我们考虑解决多个 \(k\) 的情况。

考虑当 \(k\) 增大时某些叶子的初始值会发生变化。对应 \(rt\) 为奇数的叶子,如果原来 \(x<W\),那么根据贪心原则它最好变成 \(x>W\),也就是满足 \(x+k>W\) 的初始值需要设置为 \(f[x]=1\);对应 \(rt\) 为偶数的叶子,如果原来 \(x>W\),那么根据贪心原则它最好变成 \(x<W\),也就是满足 \(x-k<W\) 的初始值需要设置成 \(f[x]=1\)

每个叶子的初始值只会修改一次,而修改之后会影响到一整条链,所以我们可以使用动态 \(dp\),现在我们把 \(g\) 的转移写成这样的形式:

\[g[u]=cnt[u]+g[heavy[u]]\cdot (-\prod_{v\in light_u} g[v]) \]

所以我们对每个点维护 \(k=-\prod_{v\in light_u} g[v],b=cnt[u]\) 的一次函数即可,函数的合并就是 \((k_1\cdot k_2,k_1\cdot b_2+b_1)\)

还有就是因为本题在修改轻儿子\(/\)修改答案的时候可能会出现除 \(0\) 的情况,所以我们还要对某些信息维护出 \(0\) 的个数方便做出发,这里我是和一次函数绑在一起写的

这道题真的需要深入的理解才能写出代码,网上的代码基本上没有什么能看的,个人觉得我的代码还算清晰,各位可以借助我的代码来梳理思路,时间复杂度 \(O(n\log^2n)\)

#include <cstdio>
#include <vector>
#include <iostream>
using namespace std; 
const int M = 200005;
const int MOD = 998244353;
#define int long long
int read()
{
	int x=0,f=1;char c;
	while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
	while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
	return x*f;
}
int n,W,L,R,dep[M],dp[M],g[M],cnt[M],ans[M],yz[M];
int Ind,num[M],bot[M],fa[M],top[M],son[M],siz[M],rt[M];
vector<int> G[M];
int qkpow(int a,int b)
{
	int r=1;
	while(b>0)
	{
		if(b&1) r=r*a%MOD;
		a=a*a%MOD;
		b>>=1;
	}
	return r;
}
struct node
{
	int k,b;
	node(int K=1,int B=0) : k(K) , b(B) {}
	node operator + (const node &r) const
		{return node(k*r.k%MOD,(k*r.b+b)%MOD);}
	node operator / (const int &r) const
	{
		if(r==0) return node(k,b-1);
		return node(k*qkpow(r,MOD-2)%MOD,b);
	}
	node operator * (const int &r) const
	{
		if(r==0) return node(k,b+1);
		return node(k*r%MOD,b);
	}
	int val() {return b?0:k;}
	int get() {return (k+b)%MOD;}
}lg[M],t[M<<2],sum;
void dfs0(int u,int p)
{
	int mx=0,mi=n;fa[u]=p;
	dep[u]=dep[p]+1;yz[u]=1;
	for(int v:G[u]) if(v^p)
	{
		dfs0(v,u);yz[u]=0;
		mx=max(mx,dp[v]);
		mi=min(mi,dp[v]);
	}
	if(yz[u]) dp[u]=u;
	else dp[u]=(dep[u]&1)?mx:mi;
}
void dfs1(int u)
{
	siz[u]=1;
	for(int v:G[u]) if(v^fa[u] && dp[v]^W)
	{
		dfs1(v);siz[u]+=siz[v];
		if(siz[son[u]]<siz[v]) son[u]=v;
	}
}
void dfs2(int u,int tp)
{
	num[u]=++Ind;
	if(!rt[u]) rt[u]=rt[fa[u]]; 
	top[u]=tp;bot[u]=num[u];
	if(son[u]) dfs2(son[u],tp),bot[u]=bot[son[u]];
	for(int v:G[u]) if(v^fa[u] && v^son[u] && dp[v]^W)
		dfs2(v,v),lg[u]=lg[u]*g[v];
	//
	cnt[u]=1;
	for(int v:G[u]) if(v^fa[u] && dp[v]^W)
		cnt[u]=cnt[u]*cnt[v]%MOD;
	if(yz[u]) cnt[u]=2;
	//
	g[u]=((dep[u]&1)^(dp[u]<W))?cnt[u]:0;
	if(yz[u]) lg[u].k=g[u];
	else lg[u].k=MOD-lg[u].k;
}
void ins(int i,int l,int r,int id,int x)
{
	if(l==r)
	{
		if(yz[x]) t[i]=node(lg[x].val(),0);
		else t[i]=node(lg[x].val(),cnt[x]);
		return ;
	}
	int mid=(l+r)>>1;
	if(mid>=id) ins(i<<1,l,mid,id,x);
	else ins(i<<1|1,mid+1,r,id,x);
	t[i]=t[i<<1]+t[i<<1|1];
}
node ask(int i,int l,int r,int L,int R)
{
	if(L<=l && r<=R) return t[i];
	int mid=(l+r)>>1;
	if(mid<L) return ask(i<<1|1,mid+1,r,L,R);
	if(mid>=R) return ask(i<<1,l,mid,L,R);
	return ask(i<<1,l,mid,L,R)
	+ask(i<<1|1,mid+1,r,L,R);
}
void upd(int u)
{
	lg[u].k=1;int r=rt[u];
	node x=ask(1,1,n,num[r],bot[r]);
	sum=sum/(cnt[r]-x.get());
	while(dp[u]!=W)
	{
		node x=ask(1,1,n,num[top[u]],bot[u]);
		ins(1,1,n,num[u],u);
		node y=ask(1,1,n,num[top[u]],bot[u]);
		u=fa[top[u]];
		if(dep[u]<dep[r]) break;
		lg[u]=lg[u]/x.get();
		lg[u]=lg[u]*y.get();
	}
	ins(1,1,n,num[r],r);
	node y=ask(1,1,n,num[r],bot[r]);
	sum=sum*(cnt[r]-y.get());
}
signed main()
{
	n=read();L=read();R=read(); 
	sum=1;int m=1;
	for(int i=1;i<n;i++)
	{
		int u=read(),v=read();
		G[u].push_back(v);
		G[v].push_back(u);
	}
	dfs0(1,0);W=dp[1];
	for(int i=1;i<=n;i++)
		if(yz[i]) m=m*2%MOD;
	for(int u=1;u<=n;u++) if(dp[u]==W)
	{
		rt[u]=u;dfs1(u);dfs2(u,u);
		g[u]=yz[u]?1:0;
		sum=sum*(cnt[u]-g[u]);
	}
	for(int i=1;i<=n;i++)
		ins(1,1,n,num[i],i);
	for(int i=1;i<=n;i++)
	{
		ans[i]=(m-sum.val())%MOD;
		if(W+i<=n && yz[W+i] && dep[rt[W+i]]%2==0) upd(W+i);
		if(W-i>=2 && yz[W-i] && dep[rt[W-i]]%2==1) upd(W-i);
	}
	ans[n]=m-1;
	for(int i=n;i>=1;i--)
		ans[i]=(ans[i]-ans[i-1])%MOD;
	for(int i=L;i<=R;i++)
		printf("%lld ",(ans[i]+MOD)%MOD);
	puts("");
}
posted @ 2022-03-02 21:04  C202044zxy  阅读(133)  评论(2编辑  收藏  举报