[ZJOI2019] Minimax搜索
一、题目
二、解法
\(\tt md\) 这题真的把我心态整炸了,真的太神了,理解都搞了整整一个晚上。
注意本题只需要改变根节点的值,我们可以预处理出 \(dp[u]\) 表示 \(u\) 节点最初的权值,然后设 \(W=dp[1]\),考虑如果 \(W\) 在 \(S\) 中那么代价一定是 \(1\),这种情况是平凡的。
第一个转化是我们本来想求 \(\max_{i\in S}|i-w_i|=k\) 的方案数,但是我们可以差分一下求出 \(\max_{i\in S}|i-w_i|\leq k\) 的方案数,也就是 \(\forall |i-w_i|\leq k\),这样就变成了关于每个点的限制。
考虑 \(dp[u]=W\) 的 \(u\) 一定构成原树上的一条链,那么我们只要改变这条链上任意一个节点的值就可以改变根,那么我们可以把这条链断开,对于每个连通块单独讨论,然后用乘法原理合并即可。
继续考虑某个点 \(x\) 和 \(W\) 链对应的根是 \(rt\),如果 \(rt\) 的深度是奇数,那么我们只有 \(x<W\) 把变成 \(x>W\) 才是可能有用的,\(x>W\) 不需要变化;如果 \(rt\) 的深度是偶数,那么我们只有 \(x>W\) 把变成 \(x<W\) 才是可能有用,\(x<W\) 不需要变化(可以把 \(<W\) 看成 \(0\),\(>W\) 看成 \(1\) 来理解这个结论)。
那么我们枚举 \(k\) 并且知道 \(S\) 之后可以贪心地确定每个点的取值。现在回到计数问题上来,考虑 \(rt\) 是奇数的情况,我们设 \(f[u]\) 表示只能使得 \(dp[u]<W\) 的 \(S\) 数量,设 \(cnt[u]\) 表示 \(u\) 节点以内的 \(2\) 的叶子个数次方(表示总情况数),那么 \(cnt[u]-f[u]\) 是可能使得 \(dp[u]>W\) 的 \(S\) 数量。
转移是容易写出的,我们根据当前点的深度奇偶性来讨论即可:
我们同时也可以写出答案的表达式,设 \(sum\) 表示 \(S\) 的总数,可以用总数减去不合法的方案得出:
注意如果 \(rt\) 是偶数那么 \(f[u]\) 的定义是可能使得 \(dp[u]>W\) 的 \(S\) 数量,但是转移方式是完全一致的,只有初始值的设置不一样,所以这里就混为一谈了。
我们可以对转移有一个简化,考虑如果 \(dep[u]\) 是奇数那么 \(g[u]=cnt[u]-f[u]\);否则 \(g[u]=f[u]\),那么我们用 \(g[u]\) 来转移就会得到很简洁的形式:
上面是单个 \(k\) 的方法,接下来我们考虑解决多个 \(k\) 的情况。
考虑当 \(k\) 增大时某些叶子的初始值会发生变化。对应 \(rt\) 为奇数的叶子,如果原来 \(x<W\),那么根据贪心原则它最好变成 \(x>W\),也就是满足 \(x+k>W\) 的初始值需要设置为 \(f[x]=1\);对应 \(rt\) 为偶数的叶子,如果原来 \(x>W\),那么根据贪心原则它最好变成 \(x<W\),也就是满足 \(x-k<W\) 的初始值需要设置成 \(f[x]=1\)
每个叶子的初始值只会修改一次,而修改之后会影响到一整条链,所以我们可以使用动态 \(dp\),现在我们把 \(g\) 的转移写成这样的形式:
所以我们对每个点维护 \(k=-\prod_{v\in light_u} g[v],b=cnt[u]\) 的一次函数即可,函数的合并就是 \((k_1\cdot k_2,k_1\cdot b_2+b_1)\)
还有就是因为本题在修改轻儿子\(/\)修改答案的时候可能会出现除 \(0\) 的情况,所以我们还要对某些信息维护出 \(0\) 的个数方便做出发,这里我是和一次函数绑在一起写的。
这道题真的需要深入的理解才能写出代码,网上的代码基本上没有什么能看的,个人觉得我的代码还算清晰,各位可以借助我的代码来梳理思路,时间复杂度 \(O(n\log^2n)\)
#include <cstdio>
#include <vector>
#include <iostream>
using namespace std;
const int M = 200005;
const int MOD = 998244353;
#define int long long
int read()
{
int x=0,f=1;char c;
while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,W,L,R,dep[M],dp[M],g[M],cnt[M],ans[M],yz[M];
int Ind,num[M],bot[M],fa[M],top[M],son[M],siz[M],rt[M];
vector<int> G[M];
int qkpow(int a,int b)
{
int r=1;
while(b>0)
{
if(b&1) r=r*a%MOD;
a=a*a%MOD;
b>>=1;
}
return r;
}
struct node
{
int k,b;
node(int K=1,int B=0) : k(K) , b(B) {}
node operator + (const node &r) const
{return node(k*r.k%MOD,(k*r.b+b)%MOD);}
node operator / (const int &r) const
{
if(r==0) return node(k,b-1);
return node(k*qkpow(r,MOD-2)%MOD,b);
}
node operator * (const int &r) const
{
if(r==0) return node(k,b+1);
return node(k*r%MOD,b);
}
int val() {return b?0:k;}
int get() {return (k+b)%MOD;}
}lg[M],t[M<<2],sum;
void dfs0(int u,int p)
{
int mx=0,mi=n;fa[u]=p;
dep[u]=dep[p]+1;yz[u]=1;
for(int v:G[u]) if(v^p)
{
dfs0(v,u);yz[u]=0;
mx=max(mx,dp[v]);
mi=min(mi,dp[v]);
}
if(yz[u]) dp[u]=u;
else dp[u]=(dep[u]&1)?mx:mi;
}
void dfs1(int u)
{
siz[u]=1;
for(int v:G[u]) if(v^fa[u] && dp[v]^W)
{
dfs1(v);siz[u]+=siz[v];
if(siz[son[u]]<siz[v]) son[u]=v;
}
}
void dfs2(int u,int tp)
{
num[u]=++Ind;
if(!rt[u]) rt[u]=rt[fa[u]];
top[u]=tp;bot[u]=num[u];
if(son[u]) dfs2(son[u],tp),bot[u]=bot[son[u]];
for(int v:G[u]) if(v^fa[u] && v^son[u] && dp[v]^W)
dfs2(v,v),lg[u]=lg[u]*g[v];
//
cnt[u]=1;
for(int v:G[u]) if(v^fa[u] && dp[v]^W)
cnt[u]=cnt[u]*cnt[v]%MOD;
if(yz[u]) cnt[u]=2;
//
g[u]=((dep[u]&1)^(dp[u]<W))?cnt[u]:0;
if(yz[u]) lg[u].k=g[u];
else lg[u].k=MOD-lg[u].k;
}
void ins(int i,int l,int r,int id,int x)
{
if(l==r)
{
if(yz[x]) t[i]=node(lg[x].val(),0);
else t[i]=node(lg[x].val(),cnt[x]);
return ;
}
int mid=(l+r)>>1;
if(mid>=id) ins(i<<1,l,mid,id,x);
else ins(i<<1|1,mid+1,r,id,x);
t[i]=t[i<<1]+t[i<<1|1];
}
node ask(int i,int l,int r,int L,int R)
{
if(L<=l && r<=R) return t[i];
int mid=(l+r)>>1;
if(mid<L) return ask(i<<1|1,mid+1,r,L,R);
if(mid>=R) return ask(i<<1,l,mid,L,R);
return ask(i<<1,l,mid,L,R)
+ask(i<<1|1,mid+1,r,L,R);
}
void upd(int u)
{
lg[u].k=1;int r=rt[u];
node x=ask(1,1,n,num[r],bot[r]);
sum=sum/(cnt[r]-x.get());
while(dp[u]!=W)
{
node x=ask(1,1,n,num[top[u]],bot[u]);
ins(1,1,n,num[u],u);
node y=ask(1,1,n,num[top[u]],bot[u]);
u=fa[top[u]];
if(dep[u]<dep[r]) break;
lg[u]=lg[u]/x.get();
lg[u]=lg[u]*y.get();
}
ins(1,1,n,num[r],r);
node y=ask(1,1,n,num[r],bot[r]);
sum=sum*(cnt[r]-y.get());
}
signed main()
{
n=read();L=read();R=read();
sum=1;int m=1;
for(int i=1;i<n;i++)
{
int u=read(),v=read();
G[u].push_back(v);
G[v].push_back(u);
}
dfs0(1,0);W=dp[1];
for(int i=1;i<=n;i++)
if(yz[i]) m=m*2%MOD;
for(int u=1;u<=n;u++) if(dp[u]==W)
{
rt[u]=u;dfs1(u);dfs2(u,u);
g[u]=yz[u]?1:0;
sum=sum*(cnt[u]-g[u]);
}
for(int i=1;i<=n;i++)
ins(1,1,n,num[i],i);
for(int i=1;i<=n;i++)
{
ans[i]=(m-sum.val())%MOD;
if(W+i<=n && yz[W+i] && dep[rt[W+i]]%2==0) upd(W+i);
if(W-i>=2 && yz[W-i] && dep[rt[W-i]]%2==1) upd(W-i);
}
ans[n]=m-1;
for(int i=n;i>=1;i--)
ans[i]=(ans[i]-ans[i-1])%MOD;
for(int i=L;i<=R;i++)
printf("%lld ",(ans[i]+MOD)%MOD);
puts("");
}