hihocoder-1347 小h的树上的朋友(lca+线段树)
题目链接:
小h的树上的朋友
时间限制:18000ms
单点时限:2000ms
内存限制:512MB
描述
小h拥有n位朋友。每位朋友拥有一个数值Vi代表他与小h的亲密度。亲密度有可能发生变化。
岁月流逝,小h的朋友们形成了一种稳定的树状关系。每位朋友恰好对应树上的一个节点。
每次小h想请两位朋友一起聚餐,他都必须把连接两位朋友的路径上的所有朋友都一起邀请上。并且聚餐的花费是这条路径上所有朋友的亲密度乘积。
小h很苦恼,他需要知道每一次聚餐的花销。小h问小y,小y当然会了,他想考考你。
输入
输入文件第一行是一个整数n,表示朋友的数目,从1开始编号。
输入文件第二行是n个正整数Vi,表示每位朋友的初始的亲密度。
接下来n-1行,每行两个整数u和v,表示u和v有一条边。
然后是一个整数m,代表操作的数目。每次操作为两者之一:
0 u v 询问邀请朋友u和v聚餐的花费
1 u v 改变朋友u的亲密度为v
1<=n,m<=5*105
Vi<=109
输出
对于每一次询问操作,你需要输出一个整数,表示聚餐所需的花费。你的答案应该模1,000,000,007输出。
- 样例输入
-
3 1 2 3 1 2 2 3 5 0 1 2 0 1 3 1 2 3 1 3 5 0 1 3
- 样例输出
-
2 6 15
题意:
中文的就不说了;
思路:
显然是一个线段树的题;
先dfs,把树映射到区间上同时求出每个点到根节点的花费,
0的时候询问:先找到lca;再dis[u]*dis[v]*w[lca]/(dis[lca]*dis[lca]);可以费马小定理快速幂求逆;
1的时候更新:dfs的时候找到了每个点的包含此点所以子节点的区间,把这个区间的dis都更新同时还要更新w[u]我就是这两个问题写漏了改了一夜晚;
AC代码:123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174#include <bits/stdc++.h>
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
using
namespace
std;
#define For(i,j,n) for(int i=j;i<=n;i++)
#define mst(ss,b) memset(ss,b,sizeof(ss));
typedef
long
long
LL;
template
<
class
T>
void
read(T&num) {
char
CH;
bool
F=
false
;
for
(CH=
getchar
();CH<
'0'
||CH>
'9'
;F= CH==
'-'
,CH=
getchar
());
for
(num=0;CH>=
'0'
&&CH<=
'9'
;num=num*10+CH-
'0'
,CH=
getchar
());
F && (num=-num);
}
int
stk[70], tp;
template
<
class
T>
inline
void
print(T p) {
if
(!p) {
puts
(
"0"
);
return
; }
while
(p) stk[++ tp] = p%10, p/=10;
while
(tp)
putchar
(stk[tp--] +
'0'
);
putchar
(
'\n'
);
}
const
LL mod=1e9+7;
const
double
PI=
acos
(-1.0);
const
int
inf=1e9;
const
int
N=5e5+10;
const
int
maxn=1e3+10;
const
double
eps=1e-10;
LL w[N],dis[N];
vector<
int
>ve[N];
int
n,in[N],a[2*N],dep[N],cnt=0,out[N];
LL pow_mod(LL x,LL y)
{
LL s=1,base=x;
while
(y)
{
if
(y&1)s=s*base%mod;
base=base*base%mod;
y>>=1;
}
return
s;
}
void
dfs(
int
x,
int
deep,
int
fa)
{
cnt++;
in[x]=cnt;
a[cnt]=x;
dep[x]=deep;
int
len=ve[x].size();
For(i,0,len-1)
{
int
y=ve[x][i];
if
(y==fa)
continue
;
dis[y]=dis[x]*w[y]%mod;
dfs(y,deep+1,x);
cnt++;
a[cnt]=x;
}
out[x]=cnt;
}
struct
Tree
{
int
l,r,lca;
LL dis;
}tr[8*N];
void
pushdown(
int
o)
{
tr[2*o].dis=tr[2*o].dis*tr[o].dis%mod;
tr[2*o+1].dis=tr[2*o+1].dis*tr[o].dis%mod;
tr[o].dis=1;
}
void
build(
int
o,
int
L,
int
R)
{
tr[o].l=L;
tr[o].r=R;
tr[o].dis=1;
if
(L==R)
{
tr[o].dis=dis[a[L]];
tr[o].lca=a[L];
return
;
}
int
mid=(L+R)>>1;
build(2*o,L,mid);
build(2*o+1,mid+1,R);
if
(dep[tr[2*o].lca]>=dep[tr[2*o+1].lca])tr[o].lca=tr[2*o+1].lca;
else
tr[o].lca=tr[2*o].lca;
}
void
update(
int
o,
int
L,
int
R,LL val)
{
if
(tr[o].l>=L&&tr[o].r<=R)
{
tr[o].dis=tr[o].dis*val%mod;
return
;
}
int
mid=(tr[o].l+tr[o].r)>>1;
if
(L>mid)update(2*o+1,L,R,val);
else
if
(R<=mid)update(2*o,L,R,val);
else
{
update(2*o,L,mid,val);
update(2*o+1,mid+1,R,val);
}
}
int
querylca(
int
o,
int
L,
int
R)
{
if
(tr[o].l>=L&&tr[o].r<=R)
return
tr[o].lca;
int
mid=(tr[o].l+tr[o].r)>>1;
if
(R<=mid)
return
querylca(2*o,L,R);
else
if
(L>mid)
return
querylca(2*o+1,L,R);
else
{
int
fl=querylca(2*o,L,mid),fr=querylca(2*o+1,mid+1,R);
if
(dep[fl]<=dep[fr])
return
fl;
else
return
fr;
}
}
LL query(
int
o,
int
pos)
{
if
(tr[o].l==tr[o].r&&tr[o].l==pos)
return
tr[o].dis;
int
mid=(tr[o].l+tr[o].r)>>1;
pushdown(o);
if
(pos>mid)
return
query(2*o+1,pos);
return
query(2*o,pos);
}
int
main()
{
read(n);
For(i,1,n)read(w[i]);
int
u,v;
For(i,1,n-1)
{
read(u);read(v);
ve[u].push_back(v);
ve[v].push_back(u);
}
dis[1]=w[1];
dfs(1,0,0);
build(1,1,cnt);
int
q,f;
read(q);
while
(q--)
{
read(f);read(u);read(v);
if
(f)
{
LL temp=w[u];
w[u]=(LL)v;
update(1,in[u],out[u],w[u]*pow_mod(temp,mod-2)%mod);
}
else
{
if
(in[u]>in[v])swap(u,v);
int
lca=querylca(1,in[u],in[v]);
LL temp=query(1,in[lca]);
temp=pow_mod(temp,mod-2);
temp=temp*temp%mod;
LL ans=query(1,in[u])*query(1,in[v])%mod*temp%mod*w[lca]%mod;
cout<<ans<<
"\n"
;
}
}
return
0;
}
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步