[被踩计划] 题解 [NOI2020]命运
[被踩计划] 题解 [NOI2020]命运
为什么叫被踩记录呢?因为感觉自己之前真的是太菜了,打算把之前联赛等考过的题目做一做,看看自已以前有多菜,所以取名叫被踩记录。
题意简述
给定一棵 \(n\) 个点的有根树,同时给定 \(m\) 条链 \((u,v)\) (保证 \(u\) 是 \(v\) 的祖先),询问有多少种给每条边赋权 \([0,1]\) 的方案使得每条链都至少有一条边边权为 \(1\) ,答案对 \(998244353\) 取模。
\(1\le n,m\le 5\times 10^5\) ,时限 2s ,空间限制 1GB 。
题目分析
考场上傻乎乎的我想了个容斥,然后就被踩了。
据说这道题目是套路题,自己做的题目果然还是太少了。
不难发现,一个点 \(u\) 往上的所有链中我们只关心另一个端点最深的那条链,因为如果这条链被覆盖了,那么其它的链都会被覆盖。
设 \(dp(u,i)\) 表示考虑完了以 \(u\) 为根的整个子树中的所有边的取值情况,除了跨越 \(u\) 节点的链以外其它链都被覆盖,并且其中跨越了 \(u\) 节点的没有被覆盖的链中最深的链另一个端点的深度为 \(i\) 的方案数。特殊的,其中 \(dp(u,0)\) 表示没有跨越 \(u\) 节点的未被覆盖的链的方案数。 \(dp(1,0)\) 就是答案。
转移就是考虑子树合并,考虑将子树 \(v\) 并入子树 \(u\) ,那么转移就是枚举 \(u-v\) 的权值,设转移后得到的数组为 \(f()\) ,取值为 \(0\) 时:
取值为 \(1\) 时:
总的方程:
设前缀和数组为 \(g(u,i)\) ,则:
这样如果直接 dp ,那么时间复杂度就是 \(\mathcal O(n^2)\) 。
由于 \(dp(u,i)\) 的第二维只有若干个位置是有值的(值非 \(0\) ),这些位置可以认为是它自己向上的链加上它子树内向上的链,所以可以使用线段树来维护所有非 \(0\) 的位置的值,而 dp 合并的时候就采用线段树合并来实现转移。
具体如何实现?观察转移方程,发现有两个前缀和数组,可以在线段树合并的过程中记录前缀和以供转移。如果出现了某一个节点代表的位置只有 \(dp(u,l\sim r)\) 有值,那么转移方程可以认为是 \(f(l\sim r)=dp(u,l\sim r)(g(v,i)+g(v,\operatorname{deep}_v-1))\) ,只需要实现区间乘即可;如果出现了某一个节点代表的位置只有 \(dp(v,l\sim r)\) 有值,那么转移方程可以认为是 \(f(l\sim r)=dp(v,l\sim r)g(u,l-1)\) ,也只需要实现区间乘即可。如果合并到了叶子节点,就直接按 dp 转移方程合并即可。
还有一点需要注意的是,从儿子节点合并上来的线段树可能有若干个位置是非法的(当 \(dp(u,i)\) 满足 \(i\ge \operatorname{deep}_u\) 时,这个状态时非法的),要把非法状态去掉,就需要再进行一遍区间赋 \(0\) (区间乘以 \(0\) )。
总的时间复杂度是 \(\mathcal O(n\log_2n)\) 。
参考代码
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ch() getchar()
#define pc(x) putchar(x)
using namespace std;
template<typename T>void read(T&x){
static char c;static int f;
for(c=ch(),f=1;c<'0'||c>'9';c=ch())if(c=='-')f=-f;
for(x=0;c>='0'&&c<='9';c=ch())x=x*10+(c&15);x*=f;
}
template<typename T>void write(T x){
static char q[65];int cnt=0;
if(x<0)pc('-'),x=-x;
q[++cnt]=x%10,x/=10;
while(x)
q[++cnt]=x%10,x/=10;
while(cnt)pc(q[cnt--]+'0');
}
const int mod=998244353,maxn=500005;
int mo(const int x){
return x>=mod?x-mod:x;
}
struct Edge{
int v,nt;
Edge(int v=0,int nt=0):
v(v),nt(nt){}
}e[maxn*2];
int hd[maxn],num;
void qwq(int u,int v){
e[++num]=Edge(v,hd[u]),hd[u]=num;
}
int dp[maxn],mx[maxn];
void dfs(int u,int fa){
dp[u]=dp[fa]+1;
for(int i=hd[u];i;i=e[i].nt){
int v=e[i].v;
if(v==fa)continue;
dfs(v,u);
}
}
struct Node{
int l,r,sum,mul;
Node(int l=0,int r=0,int sum=0,int mul=1):
l(l),r(r),sum(sum),mul(mul){}
}P[maxn*25];
int tot;
int build(int l,int r,int p){
int re=++tot,mid=(l+r)>>1;
P[re]=Node(0,0,1);
if(l==r)return re;
if(p<=mid)P[re].l=build(l,mid,p);
else P[re].r=build(mid+1,r,p);
return re;
}
void push(int x,int mul){
if(!x)return;
P[x].sum=1ll*P[x].sum*mul%mod;
P[x].mul=1ll*P[x].mul*mul%mod;
}
void pushdown(int x){
if(P[x].mul==1)return;
push(P[x].l,P[x].mul);
push(P[x].r,P[x].mul);
P[x].mul=1;
}
void pushup(int x){
P[x].sum=mo(P[P[x].l].sum+P[P[x].r].sum);
}
int Merge(int x,int y,int smu,int smv){
if(!x)return push(y,smu),y;
if(!y)return push(x,smv),x;
pushdown(x);pushdown(y);
if(!P[x].l&&!P[x].r)
P[x].sum=mo(1ll*P[x].sum*mo(smv+P[y].sum)%mod+1ll*P[y].sum*smu%mod);
else{
P[x].r=Merge(P[x].r,P[y].r,mo(smu+P[P[x].l].sum),mo(smv+P[P[y].l].sum));
P[x].l=Merge(P[x].l,P[y].l,smu,smv);pushup(x);
}
return x;
}
void cover(int x,int l,int r,int L,int R){
if(!x||(L<=l&&r<=R))return push(x,0);
pushdown(x);int mid=(l+r)>>1;
if(L<=mid)cover(P[x].l,l,mid,L,R);
if(R>mid)cover(P[x].r,mid+1,r,L,R);
return pushup(x);
}
int n,rt[maxn];
void print(int x,int l,int r){
if(l==r)return write(P[x].sum),pc(" \n"[r==n]),void();
pushdown(x);int mid=(l+r)>>1;
print(P[x].l,l,mid);print(P[x].r,mid+1,r);
}
void solve(int u,int fa){
rt[u]=build(0,n,mx[u]);
for(int i=hd[u];i;i=e[i].nt){
int v=e[i].v;if(v==fa)continue;solve(v,u);
rt[u]=Merge(rt[u],rt[v],0,P[rt[v]].sum);
}
cover(rt[u],0,n,dp[u],n);
}
int query(int x,int l,int r){
if(l==r)return P[x].sum;
pushdown(x);int mid=(l+r)>>1;
return query(P[x].l,l,mid);
}
int main(){
read(n);
for(int i=2;i<=n;++i){
int u,v;
read(u),read(v);
qwq(u,v),qwq(v,u);
}
dfs(1,0);
int m;read(m);
for(int i=1;i<=m;++i){
int u,v;
read(u),read(v);
mx[v]=max(mx[v],dp[u]);
}
solve(1,0);
write(query(rt[1],0,n)),pc('\n');
return 0;
}