[NOI2020] 命运
一、题目
二、解法
有一个比较重要的性质:对于同一个 \(v\) 我们只需要取最深的 \(u\) 去考虑即可,而且可以在 \(v\) 处处理限制 \((u,v)\),但是我们可能并不会现在就解决这个限制,可能要留到祖先去解决,这正好符合我们树形 \(dp\) 留一部分问题留给祖先考虑的特征。
设 \(dp[i][j]\) 表示最多不合法的向上延伸到了深度为 \(j\) 的祖先,其他都合法的方案数。
转移就一个一个子树地合并上去,相同的限制就取深的:
- 如果这条边选为 \(1\),清除儿子的不合法记号,如果比 \(dep_u\) 还大的 \(dp[v][i]\) 是不可能在祖先那里被解决的,所以不能统计:\(dp'[u][j]=dp[u][j]\times(\sum_{i=0}^{dep_u} dp[v][i])\)
- 如果这条边选为 \(0\),那么记号合并上来,我们讨论一下两者的大小关系:\(dp'[u][i]=dp[u][i]\times(\sum_{j\leq i}dp[v][j])\) 或者是 \(dp'[u][j]=(\sum_{i<j}dp[u][i])\times dp[v][j]\)
那么直接线段树合并就好了?也方便求和。第一种转移相当于一个全局的乘法,先不忙算,在线段树合并处理第二种转移的时候顺便算一下就行了,还有这道题坑点是真的多,一定要保持头脑清醒。时间复杂度 \(O(n\log n)\)
#include <cstdio>
#include <iostream>
#include <cstdlib>
using namespace std;
const int M = 500005;
const int MOD = 998244353;
#define int long long
int read()
{
int x=0,f=1;char c;
while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,m,tot,tmp,f[M],mx[M],rt[M],dep[M];
int cnt,tg[20*M],dp[20*M],ls[20*M],rs[20*M];
struct edge
{
int v,next;
edge(int V=0,int N=0) : v(V) , next(N) {}
}e[2*M];
void pre(int u,int fa)
{
dep[u]=dep[fa]+1;
for(int i=f[u];i;i=e[i].next)
{
int v=e[i].v;
if(v==fa) continue;
pre(v,u);
}
}
void ins(int &x,int l,int r,int f)
{
x=++cnt;
dp[x]=tg[x]=1;
if(l==r) return ;
int mid=(l+r)>>1;
if(mid>=f) ins(ls[x],l,mid,f);
else ins(rs[x],mid+1,r,f);
}
void mul(int x,int y)
{
dp[x]=dp[x]*y%MOD;
tg[x]=tg[x]*y%MOD;
}
void down(int x)
{
if(tg[x]!=1)//错过了
{
mul(ls[x],tg[x]);
mul(rs[x],tg[x]);
tg[x]=1;
}
}
int merge(int x,int y,int l,int r,int s1,int s2)
//s1表示dp[v][j]的求和,累加右子树
//s2表示dp[u][i]的求和,也是累加右子树
{
if(!x && !y) return 0;
if(!x)//此时考虑一下第二种转移
{
mul(y,s2);
return y;
}
if(!y)//考虑第一种转移
{
mul(x,s1+tmp);
return x;
}
if(l==r)
{
//printf("%d %d %d\n",x,s1,s2);
dp[x]=(dp[x]*(s1+dp[y]+tmp)%MOD+dp[y]*s2)%MOD;
return x;
}
int mid=(l+r)>>1;
down(x);down(y);
//那个傻逼东西会改,先访问右儿子
rs[x]=merge(rs[x],rs[y],mid+1,r,(s1+dp[ls[y]])%MOD,(s2+dp[ls[x]])%MOD);
ls[x]=merge(ls[x],ls[y],l,mid,s1,s2);
dp[x]=(dp[ls[x]]+dp[rs[x]])%MOD;
return x;
}
int find(int x,int l,int r,int L,int R)
{
if(L>r || l>R) return 0;
if(L<=l && r<=R) return dp[x];
int mid=(l+r)>>1;down(x);
return find(rs[x],mid+1,r,L,R)+find(ls[x],l,mid,L,R);
}
void dfs(int u,int fa)
{
ins(rt[u],0,n,mx[u]);
for(int i=f[u];i;i=e[i].next)
{
int v=e[i].v;
if(v==fa) continue;
dfs(v,u);
tmp=find(rt[v],0,n,0,dep[u]);//错过了
rt[u]=merge(rt[u],rt[v],0,n,0,0);
}
}
signed main()
{
n=read();
for(int i=1;i<n;i++)
{
int u=read(),v=read();
e[++tot]=edge(v,f[u]),f[u]=tot;
e[++tot]=edge(u,f[v]),f[v]=tot;
}
pre(1,0);
m=read();
for(int i=1;i<=m;i++)
{
int u=read(),v=read();
mx[v]=max(mx[v],dep[u]);
}
dfs(1,0);
printf("%d\n",find(rt[1],0,n,0,0));
}