P9208 虚人「无」

Question 问题 P9208 虚人「无」

给定二元序列 \(\{(v_i,c_i)\}\) 和一棵以 \(1\) 为根的有根树。第 \(i\) 个点的点权是 \((v_i,c_i)\)

  • 定义一个非根节点的权值为其子树内的 \(c\) 的积乘上其子树补的 \(v\) 的积。
  • 定义一个根节点的权值为其子树内的 \(c\) 的积。

形式化的讲,若 \(u\) 不为根节点,则 \(u\) 的权值 \(f_u\) 为:

\[f_u=\prod\limits_{v\in \operatorname{substree}(u)} c_v\times \prod\limits_{v\notin \operatorname{substree}(u)} v_v \]

否则,其权值 \(f_u\) 为:

\[f_u=\prod\limits_{v=1}^n c_v \]

试求整棵树所有节点的权值之和,答案对 \(m\) 取模。请注意:保证 \(\bm m\) 是质数

Analysis 分析

定义:

\(dfn_x\)\(x\)\(dfs\) 序。

\(siz_x\)\(x\) 的子树大小。

先讲正解:dfs 序 & 线段树查询

我们跑一遍 \(dfs\) 序之后,思考对于一个点 \(x\),其子节点所在的 \(dfs\) 序必然在区间 \(dfn_{x} \sim dfn_{x}+siz_{x}-1\) 之中。

接下来我们只需要维护区间内 \(c,v\) 的乘积,显然我们只需要用线段树维护一下乘积就行了。

以下是错解:逆元(可选择不看)

考虑一个非常自然就能想到的方法,看到取模,想用逆元。但是在本题不行。因为使用逆元需要当所有参与计算答案的值均存在逆元,即模数 \(m \perp c_i,v_i\)

所以这道题可以通过特殊构造卡掉逆元做法。

Code 代码

#include<bits/stdc++.h>
#define rint register int
#define ls (p<<1)
#define rs (p<<1|1)
using namespace std;
typedef long long ll;
const int N=3e5+8;
int n,x,y,tot,cnt,mod;
int head[N],siz[N],dfn[N];
ll c[N],v[N],ans;
struct edge{
	int nxt,to;
}e[N<<1];
struct tree{
	int l,r;
	ll csum,vsum;
	#define l(p) t[p].l
	#define r(p) t[p].r
	#define cs(p) t[p].csum
	#define vs(p) t[p].vsum
}t[N<<2];
inline void add(int a,int b){e[++tot]={head[a],b};head[a]=tot;}
inline void add(int a,int b,string s){add(a,b);add(b,a);}
inline void dfs(int x,int fa){
	siz[x]=1;dfn[x]=++cnt;
	for(rint i=head[x];i;i=e[i].nxt){
		int y=e[i].to;
		if(y!=fa){
			dfs(y,x);
			siz[x]+=siz[y];
		}
	}
}
void build(int p,int l,int r){
	l(p)=l;r(p)=r;
	if(l==r){cs(p)=c[l],vs(p)=v[l];return;}
	int Mid=l+r>>1;
	build(ls,l,Mid),build(rs,Mid+1,r);
	cs(p)=cs(ls)*cs(rs)%mod;
	vs(p)=vs(ls)*vs(rs)%mod;
	return ;
}
ll query(int p,int l,int r,int cmd){
	if(l<=l(p)&&r(p)<=r){
		if(cmd) return vs(p);
		else return cs(p);
	}
	ll res=1;int Mid=l(p)+r(p)>>1;
	if(l<=Mid) res=query(ls,l,r,cmd);
	if(r>Mid) res=res*query(rs,l,r,cmd)%mod;
	return res;
}
int main(){
	read(n,mod);
	for(rint i=1;i<n;i++) read(x,y),add(x,y,"Mr.Az");
	dfs(1,0);
	for(rint i=1;i<=n;i++) read(c[dfn[i]]);
	for(rint i=1;i<=n;i++) read(v[dfn[i]]);
	build(1,1,n);
	for(rint i=1;i<=n;i++){
		ll res=1;
		if(dfn[i]!=1) res=query(1,1,dfn[i]-1,1);
		res=res*query(1,dfn[i],dfn[i]+siz[i]-1,0)%mod;
		if(dfn[i]+siz[i]<=n) res=res*query(1,dfn[i]+siz[i],n,1)%mod;
		ans=(ans+res)%mod;
	}
	printf("%lld\n",ans);
    return 0;
}
posted @ 2023-05-27 14:49  Mr_Azz  阅读(14)  评论(0编辑  收藏  举报