hihocoder-1347 小h的树上的朋友(lca+线段树)

题目链接:

小h的树上的朋友

时间限制:18000ms
单点时限:2000ms
内存限制:512MB

描述

小h拥有n位朋友。每位朋友拥有一个数值Vi代表他与小h的亲密度。亲密度有可能发生变化。

岁月流逝,小h的朋友们形成了一种稳定的树状关系。每位朋友恰好对应树上的一个节点。

每次小h想请两位朋友一起聚餐,他都必须把连接两位朋友的路径上的所有朋友都一起邀请上。并且聚餐的花费是这条路径上所有朋友的亲密度乘积。

小h很苦恼,他需要知道每一次聚餐的花销。小h问小y,小y当然会了,他想考考你。

输入

输入文件第一行是一个整数n,表示朋友的数目,从1开始编号。

输入文件第二行是n个正整数Vi,表示每位朋友的初始的亲密度。

接下来n-1行,每行两个整数u和v,表示u和v有一条边。

然后是一个整数m,代表操作的数目。每次操作为两者之一:

0 u v 询问邀请朋友u和v聚餐的花费

1 u v 改变朋友u的亲密度为v

1<=n,m<=5*105

Vi<=109

输出

对于每一次询问操作,你需要输出一个整数,表示聚餐所需的花费。你的答案应该模1,000,000,007输出。

样例输入
3
1 2 3
1 2
2 3
5
0 1 2
0 1 3
1 2 3
1 3 5
0 1 3
样例输出
2
6
15
题意:
中文的就不说了;

思路:
显然是一个线段树的题;
先dfs,把树映射到区间上同时求出每个点到根节点的花费,
0的时候询问:先找到lca;再dis[u]*dis[v]*w[lca]/(dis[lca]*dis[lca]);可以费马小定理快速幂求逆;
1的时候更新:dfs的时候找到了每个点的包含此点所以子节点的区间,把这个区间的dis都更新同时还要更新w[u]我就是这两个问题写漏了改了一夜晚;

AC代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
#include <bits/stdc++.h>
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
 
using namespace std;
 
#define For(i,j,n) for(int i=j;i<=n;i++)
#define mst(ss,b) memset(ss,b,sizeof(ss));
 
typedef  long long LL;
 
template<class T> void read(T&num) {
    char CH; bool F=false;
    for(CH=getchar();CH<'0'||CH>'9';F= CH=='-',CH=getchar());
    for(num=0;CH>='0'&&CH<='9';num=num*10+CH-'0',CH=getchar());
    F && (num=-num);
}
int stk[70], tp;
template<class T> inline void print(T p) {
    if(!p) { puts("0"); return; }
    while(p) stk[++ tp] = p%10, p/=10;
    while(tp) putchar(stk[tp--] + '0');
    putchar('\n');
}
 
const LL mod=1e9+7;
const double PI=acos(-1.0);
const int inf=1e9;
const int N=5e5+10;
const int maxn=1e3+10;
const double eps=1e-10;
 
LL w[N],dis[N];
vector<int>ve[N];
 
int n,in[N],a[2*N],dep[N],cnt=0,out[N];
 
LL pow_mod(LL x,LL y)
{
    LL s=1,base=x;
    while(y)
    {
        if(y&1)s=s*base%mod;
        base=base*base%mod;
        y>>=1;
    }
    return s;
}
 
void dfs(int x,int deep,int fa)
{
    cnt++;
    in[x]=cnt;
    a[cnt]=x;
    dep[x]=deep;
    int len=ve[x].size();
    For(i,0,len-1)
    {
        int y=ve[x][i];
        if(y==fa)continue;
        dis[y]=dis[x]*w[y]%mod;
        dfs(y,deep+1,x);
        cnt++;
        a[cnt]=x;
    }
    out[x]=cnt;
}
struct Tree
{
    int l,r,lca;
    LL dis;
}tr[8*N];
void pushdown(int o)
{
    tr[2*o].dis=tr[2*o].dis*tr[o].dis%mod;
    tr[2*o+1].dis=tr[2*o+1].dis*tr[o].dis%mod;
    tr[o].dis=1;
}
void build(int o,int L,int R)
{
    tr[o].l=L;
    tr[o].r=R;
    tr[o].dis=1;
    if(L==R)
    {
        tr[o].dis=dis[a[L]];
        tr[o].lca=a[L];
        return ;
    }
    int mid=(L+R)>>1;
    build(2*o,L,mid);
    build(2*o+1,mid+1,R);
    if(dep[tr[2*o].lca]>=dep[tr[2*o+1].lca])tr[o].lca=tr[2*o+1].lca;
    else tr[o].lca=tr[2*o].lca;
}
void update(int o,int L,int R,LL val)
{
    if(tr[o].l>=L&&tr[o].r<=R)
    {
        tr[o].dis=tr[o].dis*val%mod;
        return ;
    }
    int mid=(tr[o].l+tr[o].r)>>1;
 
    if(L>mid)update(2*o+1,L,R,val);
    else if(R<=mid)update(2*o,L,R,val);
    else {
        update(2*o,L,mid,val);
        update(2*o+1,mid+1,R,val);
    }
}
int querylca(int o,int L,int R)
{
 
        if(tr[o].l>=L&&tr[o].r<=R)return tr[o].lca;
        int mid=(tr[o].l+tr[o].r)>>1;
        if(R<=mid)return querylca(2*o,L,R);
        else if(L>mid)return querylca(2*o+1,L,R);
        else
        {
            int fl=querylca(2*o,L,mid),fr=querylca(2*o+1,mid+1,R);
            if(dep[fl]<=dep[fr])return fl;
            else return fr;
        }
}
LL query(int o,int pos)
{
    if(tr[o].l==tr[o].r&&tr[o].l==pos)return tr[o].dis;
    int mid=(tr[o].l+tr[o].r)>>1;
    pushdown(o);
    if(pos>mid)return query(2*o+1,pos);
    return query(2*o,pos);
}
int main()
{
        read(n);
        For(i,1,n)read(w[i]);
        int u,v;
        For(i,1,n-1)
        {
            read(u);read(v);
            ve[u].push_back(v);
            ve[v].push_back(u);
        }
        dis[1]=w[1];
        dfs(1,0,0);
        build(1,1,cnt);
        int q,f;
        read(q);
        while(q--)
        {
            read(f);read(u);read(v);
            if(f)
            {
                LL temp=w[u];
                w[u]=(LL)v;
                update(1,in[u],out[u],w[u]*pow_mod(temp,mod-2)%mod);
            }
            else
            {
                if(in[u]>in[v])swap(u,v);
                int lca=querylca(1,in[u],in[v]);
                LL temp=query(1,in[lca]);
                temp=pow_mod(temp,mod-2);
                temp=temp*temp%mod;
                LL ans=query(1,in[u])*query(1,in[v])%mod*temp%mod*w[lca]%mod;
                cout<<ans<<"\n";
            }
        }       
        return 0;
}

  

 

posted @   LittlePointer  阅读(562)  评论(2编辑  收藏  举报
努力加载评论中...
点击右上角即可分享
微信分享提示