[NOI2020] 命运

一、题目

点此看题

二、解法

有一个比较重要的性质:对于同一个 \(v\) 我们只需要取最深的 \(u\) 去考虑即可,而且可以在 \(v\) 处处理限制 \((u,v)\),但是我们可能并不会现在就解决这个限制,可能要留到祖先去解决,这正好符合我们树形 \(dp\) 留一部分问题留给祖先考虑的特征。

\(dp[i][j]\) 表示最多不合法的向上延伸到了深度为 \(j\) 的祖先,其他都合法的方案数。

转移就一个一个子树地合并上去,相同的限制就取深的:

  • 如果这条边选为 \(1\),清除儿子的不合法记号,如果比 \(dep_u\) 还大的 \(dp[v][i]\) 是不可能在祖先那里被解决的,所以不能统计:\(dp'[u][j]=dp[u][j]\times(\sum_{i=0}^{dep_u} dp[v][i])\)
  • 如果这条边选为 \(0\),那么记号合并上来,我们讨论一下两者的大小关系:\(dp'[u][i]=dp[u][i]\times(\sum_{j\leq i}dp[v][j])\) 或者是 \(dp'[u][j]=(\sum_{i<j}dp[u][i])\times dp[v][j]\)

那么直接线段树合并就好了?也方便求和。第一种转移相当于一个全局的乘法,先不忙算,在线段树合并处理第二种转移的时候顺便算一下就行了,还有这道题坑点是真的多,一定要保持头脑清醒。时间复杂度 \(O(n\log n)\)

#include <cstdio>
#include <iostream>
#include <cstdlib>
using namespace std;
const int M = 500005;
const int MOD = 998244353;
#define int long long
int read()
{
	int x=0,f=1;char c;
	while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
	while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
	return x*f;
}
int n,m,tot,tmp,f[M],mx[M],rt[M],dep[M];
int cnt,tg[20*M],dp[20*M],ls[20*M],rs[20*M];
struct edge
{
	int v,next;
	edge(int V=0,int N=0) : v(V) , next(N) {}
}e[2*M];
void pre(int u,int fa)
{
	dep[u]=dep[fa]+1;
	for(int i=f[u];i;i=e[i].next)
	{
		int v=e[i].v;
		if(v==fa) continue;
		pre(v,u);
	}
}
void ins(int &x,int l,int r,int f)
{
	x=++cnt;
	dp[x]=tg[x]=1;
	if(l==r) return ;
	int mid=(l+r)>>1;
	if(mid>=f) ins(ls[x],l,mid,f);
	else ins(rs[x],mid+1,r,f);
}
void mul(int x,int y)
{
	dp[x]=dp[x]*y%MOD;
	tg[x]=tg[x]*y%MOD;
}
void down(int x)
{
	if(tg[x]!=1)//错过了 
	{
		mul(ls[x],tg[x]);
		mul(rs[x],tg[x]);
		tg[x]=1;
	}
}
int merge(int x,int y,int l,int r,int s1,int s2)
//s1表示dp[v][j]的求和,累加右子树 
//s2表示dp[u][i]的求和,也是累加右子树 
{
	if(!x && !y) return 0;
	if(!x)//此时考虑一下第二种转移 
	{
		mul(y,s2);
		return y;
	}
	if(!y)//考虑第一种转移
	{
		mul(x,s1+tmp);
		return x;
	}
	if(l==r)
	{
		//printf("%d %d %d\n",x,s1,s2);
		dp[x]=(dp[x]*(s1+dp[y]+tmp)%MOD+dp[y]*s2)%MOD;
		return x;
	}
	int mid=(l+r)>>1;
	down(x);down(y);
	//那个傻逼东西会改,先访问右儿子 
	rs[x]=merge(rs[x],rs[y],mid+1,r,(s1+dp[ls[y]])%MOD,(s2+dp[ls[x]])%MOD);
	ls[x]=merge(ls[x],ls[y],l,mid,s1,s2);
	dp[x]=(dp[ls[x]]+dp[rs[x]])%MOD;
	return x;
}
int find(int x,int l,int r,int L,int R)
{
	if(L>r || l>R) return 0;
	if(L<=l && r<=R) return dp[x];
	int mid=(l+r)>>1;down(x);
	return find(rs[x],mid+1,r,L,R)+find(ls[x],l,mid,L,R);
}
void dfs(int u,int fa)
{
	ins(rt[u],0,n,mx[u]);
	for(int i=f[u];i;i=e[i].next)
	{
		int v=e[i].v;
		if(v==fa) continue;
		dfs(v,u);
		tmp=find(rt[v],0,n,0,dep[u]);//错过了 
		rt[u]=merge(rt[u],rt[v],0,n,0,0);
	}
}
signed main()
{
	n=read();
	for(int i=1;i<n;i++)
	{
		int u=read(),v=read();
		e[++tot]=edge(v,f[u]),f[u]=tot;
		e[++tot]=edge(u,f[v]),f[v]=tot;
	}
	pre(1,0);
	m=read();
	for(int i=1;i<=m;i++)
	{
		int u=read(),v=read();
		mx[v]=max(mx[v],dep[u]);
	}
	dfs(1,0);
	printf("%d\n",find(rt[1],0,n,0,0));
}
posted @ 2021-03-22 21:52  C202044zxy  阅读(105)  评论(0编辑  收藏  举报