LOJ3340 「NOI2020」命运
博主有幸参加了NOI2020,考场上的经历和心得请见这篇文章。这里就不唠叨了。
本题题解
为了方便表述,我们把边重要、不重要,称为:颜色为\(1\)、\(0\)。本题相当于要求对给边染色的方案计数。
前\(40\)分部分分,是对\(m\)个限制做容斥。也就是\(2^m\)枚举哪些路径强制设为\(0\),不在这些路径上的边可以随意。问题转化为求路径并,可以用树链剖分或虚树维护。由于与正解做法关联不大,这里不细讲了。
正解的起点是一个DP。设\(dp[u][i]\)表示假设\(u\)的子树外边的颜色都已经确定,且\(u\)的“祖先边”(\(u\)到根路径上所有边)中,深度最大(离它最近)的颜色为\(1\)的边,深度为\(i\),此时\(u\)的子树内所有边合法的染色方案数。
容易发现,一个子树内的染色方案数,只和这个子树“祖先边”中深度最大的颜色为\(1\)的边的深度有关,而这个东西我们已经写在了状态里(就是\(i\))。因此这个状态非常合理。
更具体来讲,我们定义点的深度为它到根路径上的边数。边的深度为它连接的儿子节点的深度。因为根节点(深度为\(0\))没有父亲边,所以不存在深度为\(0\)的边。不过我们可以假装有这样一条边,且这条边的颜色永远是\(1\)。则DP的答案就是\(dp[1][0]\),其中\(0\)就是我们假想出来的这条边的深度。
至于\(m\)条限制,我们对每个点\(u\),维护它“祖先边”中,离它最近的、颜色为\(1\)的边,深度至少是多少,记为\(\text{lim}[u]\)。初始时,\(\text{lim}[u]=0\),也就是我们假想的那个颜色永远是\(1\)的边。每读入一条限制\((u,v)\),就令\(\text{lim}[v]:=\max(\text{lim}[v],\text{dep}[u]+1)\)。同时,每个节点的\(\text{lim}\),还要对\(\text{lim}[fa(u)]\)取\(\max\)。
容易发现,对于\(i\notin [\text{lim}[u],\text{dep}[u]]\),\(dp[u][i]=0\)。
对于其他的\(i\),可以从儿子转移过来。不难写出:
这样直接DP是\(O(n^2)\)的。
继续优化,对于每个点\(u\),考虑把\(dp[u]\)的第二维搬到线段树上,然后做线段树合并。考虑在转移过程中,我们的线段树需要做哪些操作。
- 假设线段树初值全部为\(0\)。那我们要先把\([\text{lim}[u],\text{dep}[u]]\)这个区间赋值为\(1\)。
- 我们要对所有\(0\leq i<\text{dep}[u]\),令\(dp[u][i]\texttt{+=}dp[u][\text{dep[u]}]\),这是为了从\(u\)向父亲转移时更方便(看一下转移式就知道了)。
- 在第2条的基础上,转移时,我相当于要合并两棵线段树,并实现把所有对应位置相乘。
前两条操作,可以看成区间加。第3条操作又涉及乘法。所以我们的线段树,维护区间加法和乘法两个懒标记。
这里需要注意一下,下放标记的顺序。更新乘法标记时,把已有的加法标记也乘上这个数。更新加法标记时则不动乘法标记。下放标记(push_down
)时,先下放乘法标记,再下放加法标记。这里面的原因并不复杂,可以自行思考一下。
同时我们还要支持单点查询。这个在访问到线段树叶子节点时,直接反回它的加法标记值即可。
这样合并\(O(n)\)次,做\(O(n)\)次修改,时间复杂度是\(O(n\log n)\)的。
参考代码(建议使用读入优化,详见本博客公告):
//problem:LOJ3340
#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
template<typename T>inline void ckmax(T& x,T y){x=(y>x?y:x);}
template<typename T>inline void ckmin(T& x,T y){x=(y<x?y:x);}
const int MAXN=5e5,MAXM=5e5;
const int MOD=998244353;
inline int mod1(int x){return x<MOD?x:x-MOD;}
inline int mod2(int x){return x<0?x+MOD:x;}
inline void add(int& x,int y){x=mod1(x+y);}
inline void sub(int& x,int y){x=mod2(x-y);}
inline int pow_mod(int x,int i){int y=1;while(i){if(i&1)y=(ll)y*x%MOD;x=(ll)x*x%MOD;i>>=1;}return y;}
int n,m;
vector<int>G[MAXN+5];
struct Request{
// u是祖先,v是后代
int u,v;
}q[MAXM+5];
int fa[MAXN+5],dep[MAXN+5];
int lim[MAXN+5];
void dfs_prep(int u){
for(int i=0;i<SZ(G[u]);++i){
int v=G[u][i];
if(v==fa[u]) continue;
fa[v]=u;
dep[v]=dep[u]+1;
dfs_prep(v);
}
}
struct SegmentTree{
int rt[MAXN+5];
int cnt;
vector<int>bin;
int ls[MAXN*100+5],rs[MAXN*100+5];
int tag_add[MAXN*100+5],tag_mul[MAXN*100+5];
int new_node(int _mul=1,int _add=0){
int x;
if(SZ(bin)){
x=bin.back();
bin.pop_back();
}
else{
x=++cnt;
}
ls[x]=rs[x]=0;
tag_mul[x]=_mul;
tag_add[x]=_add;
return x;
}
void upd_mul(int x,int v){
tag_mul[x]=(ll)tag_mul[x]*v%MOD;
tag_add[x]=(ll)tag_add[x]*v%MOD;
}
void upd_add(int x,int v){
add(tag_add[x],v);
}
void push_down(int x){
if(tag_mul[x]!=1){
if(!ls[x]) ls[x]=new_node();
if(!rs[x]) rs[x]=new_node();
upd_mul(ls[x],tag_mul[x]);
upd_mul(rs[x],tag_mul[x]);
tag_mul[x]=1;
}
if(tag_add[x]){
if(!ls[x]) ls[x]=new_node();
if(!rs[x]) rs[x]=new_node();
upd_add(ls[x],tag_add[x]);
upd_add(rs[x],tag_add[x]);
tag_add[x]=0;
}
}
void modify_add(int& x,int l,int r,int ql,int qr,int v){
if(ql<=l && qr>=r){
upd_add(x,v);
return;
}
int mid=(l+r)>>1;
push_down(x);
if(ql<=mid)
modify_add(ls[x],l,mid,ql,qr,v);
if(qr>mid)
modify_add(rs[x],mid+1,r,ql,qr,v);
}
int query(int x,int l,int r,int pos){
if(l==r){
return tag_add[x];
}
int mid=(l+r)>>1;
push_down(x);
if(pos<=mid)
return query(ls[x],l,mid,pos);
else
return query(rs[x],mid+1,r,pos);
}
int merge(int x, int y){
if (!ls[x] && !rs[x]) swap(x, y);
if (!ls[y] && !rs[y]){
upd_mul(x,tag_add[y]);
return x;
}
push_down(x), push_down(y);
ls[x] = merge(ls[x], ls[y]);
rs[x] = merge(rs[x], rs[y]);
return x;
}
SegmentTree(){}
}T;
/*
int dp[MAXN+5][MAXN+5];
void dfs_dp(int u){
for(int j=lim[u];j<=dep[u];++j){
dp[u][j]=1;
}
for(int i=0;i<SZ(G[u]);++i){
int v=G[u][i];
if(v==fa[u]) continue;
dfs_dp(v);
for(int j=lim[u];j<=dep[u];++j){
dp[u][j]=(ll)dp[u][j]*(dp[v][j]+dp[v][dep[v]])%MOD;
}
}
}
*/
void dfs_dp(int u){
lim[u]=max(lim[u],lim[fa[u]]);
T.rt[u]=T.new_node(0,0);
T.modify_add(T.rt[u],0,n-1,lim[u],dep[u],1);
for(int i=0;i<SZ(G[u]);++i){
int v=G[u][i];
if(v==fa[u]) continue;
dfs_dp(v);
T.rt[u]=T.merge(T.rt[u],T.rt[v]);
}
if(u==1) return;
int v=T.query(T.rt[u],0,n-1,dep[u]);
T.modify_add(T.rt[u],0,n-1,0,dep[u]-1,v);
}
int main() {
// freopen("destiny.in","r",stdin);
// freopen("destiny.out","w",stdout);
cin>>n;
for(int i=1;i<n;++i){
int u,v;
cin>>u>>v;
G[u].pb(v);
G[v].pb(u);
}
dfs_prep(1);
cin>>m;
for(int i=1;i<=m;++i){
cin>>q[i].u>>q[i].v;
ckmax(lim[q[i].v],dep[q[i].u]+1);
}
dfs_dp(1);
cout<<T.query(T.rt[1],0,n-1,0)<<endl;
return 0;
}