P9208 虚人「无」
Question 问题 P9208 虚人「无」
给定二元序列 \(\{(v_i,c_i)\}\) 和一棵以 \(1\) 为根的有根树。第 \(i\) 个点的点权是 \((v_i,c_i)\)。
- 定义一个非根节点的权值为其子树内的 \(c\) 的积乘上其子树补的 \(v\) 的积。
- 定义一个根节点的权值为其子树内的 \(c\) 的积。
形式化的讲,若 \(u\) 不为根节点,则 \(u\) 的权值 \(f_u\) 为:
\[f_u=\prod\limits_{v\in \operatorname{substree}(u)} c_v\times \prod\limits_{v\notin \operatorname{substree}(u)} v_v
\]
否则,其权值 \(f_u\) 为:
\[f_u=\prod\limits_{v=1}^n c_v
\]
试求整棵树所有节点的权值之和,答案对 \(m\) 取模。请注意:保证 \(\bm m\) 是质数。
Analysis 分析
定义:
\(dfn_x\) 为 \(x\) 的 \(dfs\) 序。
\(siz_x\) 为 \(x\) 的子树大小。
先讲正解:dfs 序 & 线段树查询
我们跑一遍 \(dfs\) 序之后,思考对于一个点 \(x\),其子节点所在的 \(dfs\) 序必然在区间 \(dfn_{x} \sim dfn_{x}+siz_{x}-1\) 之中。
接下来我们只需要维护区间内 \(c,v\) 的乘积,显然我们只需要用线段树维护一下乘积就行了。
以下是错解:逆元(可选择不看)
考虑一个非常自然就能想到的方法,看到取模,想用逆元。但是在本题不行。因为使用逆元需要当所有参与计算答案的值均存在逆元,即模数 \(m \perp c_i,v_i\)。
所以这道题可以通过特殊构造卡掉逆元做法。
Code 代码
#include<bits/stdc++.h>
#define rint register int
#define ls (p<<1)
#define rs (p<<1|1)
using namespace std;
typedef long long ll;
const int N=3e5+8;
int n,x,y,tot,cnt,mod;
int head[N],siz[N],dfn[N];
ll c[N],v[N],ans;
struct edge{
int nxt,to;
}e[N<<1];
struct tree{
int l,r;
ll csum,vsum;
#define l(p) t[p].l
#define r(p) t[p].r
#define cs(p) t[p].csum
#define vs(p) t[p].vsum
}t[N<<2];
inline void add(int a,int b){e[++tot]={head[a],b};head[a]=tot;}
inline void add(int a,int b,string s){add(a,b);add(b,a);}
inline void dfs(int x,int fa){
siz[x]=1;dfn[x]=++cnt;
for(rint i=head[x];i;i=e[i].nxt){
int y=e[i].to;
if(y!=fa){
dfs(y,x);
siz[x]+=siz[y];
}
}
}
void build(int p,int l,int r){
l(p)=l;r(p)=r;
if(l==r){cs(p)=c[l],vs(p)=v[l];return;}
int Mid=l+r>>1;
build(ls,l,Mid),build(rs,Mid+1,r);
cs(p)=cs(ls)*cs(rs)%mod;
vs(p)=vs(ls)*vs(rs)%mod;
return ;
}
ll query(int p,int l,int r,int cmd){
if(l<=l(p)&&r(p)<=r){
if(cmd) return vs(p);
else return cs(p);
}
ll res=1;int Mid=l(p)+r(p)>>1;
if(l<=Mid) res=query(ls,l,r,cmd);
if(r>Mid) res=res*query(rs,l,r,cmd)%mod;
return res;
}
int main(){
read(n,mod);
for(rint i=1;i<n;i++) read(x,y),add(x,y,"Mr.Az");
dfs(1,0);
for(rint i=1;i<=n;i++) read(c[dfn[i]]);
for(rint i=1;i<=n;i++) read(v[dfn[i]]);
build(1,1,n);
for(rint i=1;i<=n;i++){
ll res=1;
if(dfn[i]!=1) res=query(1,1,dfn[i]-1,1);
res=res*query(1,dfn[i],dfn[i]+siz[i]-1,0)%mod;
if(dfn[i]+siz[i]<=n) res=res*query(1,dfn[i]+siz[i],n,1)%mod;
ans=(ans+res)%mod;
}
printf("%lld\n",ans);
return 0;
}