BZOJ3999 [TJOI2015]旅游 【树剖 + 线段树】
题目
为了提高智商,ZJY准备去往一个新世界去旅游。这个世界的城市布局像一棵树。每两座城市之间只有一条路径可
以互达。每座城市都有一种宝石,有一定的价格。ZJY为了赚取最高利益,她会选择从A城市买入再转手卖到B城市
。由于ZJY买宝石时经常卖萌,因而凡是ZJY路过的城市,这座城市的宝石价格会上涨。让我们来算算ZJY旅游完之
后能够赚取的最大利润。(如a城市宝石价格为v,则ZJY出售价格也为v)
输入格式
第一行输入一个正整数N,表示城市个数。
接下来一行输入N个正整数表示每座城市宝石的最初价格p,每个宝石的初始价格不超过100。
第三行开始连续输入N-1行,每行有两个数字x和y。表示x城市和y城市有一条路径。城市编号从1开始。
下一行输入一个整数Q,表示询问次数。
接下来Q行,每行输入三个正整数a,b,v,表示ZJY从a旅游到b,城市宝石上涨v。
1≤ N≤50000, 1≤Q ≤50000
输出格式
对于每次询问,输出ZJY可能获得的最大利润,如果亏本则输出0。
输入样例
3
1 2 3
1 2
2 3
2
1 2 100
1 3 100
输出样例
1
1
题解
题意是树上可修改的两点间有序差值最大值
如果是序列上的话一个线段树维护左右区间最值和答案就可以分类讨论跨不跨过中点得出各个点的答案从而做到维护答案了
如果是树上的话,考虑树剖,将路径拆成了几段几段
答案要么在一段内,要么最值分别分布在两段内
我们就用每一段的答案更新答案,再单独拿出所有的\(O(logn)\)段,考虑相互影响来更新答案
细节上要注意树上路径的方向
#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#define LL long long int
#define Redge(u) for (int k = h[u],to; k; k = ed[k].nxt)
#define REP(i,n) for (int i = 1; i <= (n); i++)
#define BUG(s,n) for (int i = 1; i <= (n); i++) cout<<s[i]<<' '; puts("");
#define ls (u << 1)
#define rs (u << 1 | 1)
using namespace std;
const int maxn = 50005,INF = 1000000000;
inline int read(){
int out = 0,flag = 1; char c = getchar();
while (c < 48 || c > 57){if (c == '-') flag = -1; c = getchar();}
while (c >= 48 && c <= 57){out = (out << 3) + (out << 1) + c - 48; c = getchar();}
return out * flag;
}
int h[maxn],ne = 2;
struct EDGE{int to,nxt;}ed[maxn << 1];
inline void Build(int u,int v){
ed[ne] = (EDGE){v,h[u]}; h[u] = ne++;
ed[ne] = (EDGE){u,h[v]}; h[v] = ne++;
}
int n,A[maxn];
int fa[maxn],siz[maxn],son[maxn],top[maxn],dep[maxn],id[maxn],hash[maxn],cnt;
void dfs1(int u){
siz[u] = 1;
Redge(u) if ((to = ed[k].to) != fa[u]){
fa[to] = u; dep[to] = dep[u] + 1;
dfs1(to);
siz[u] += siz[to];
if (!son[u] || siz[son[u]] < siz[to]) son[u] = to;
}
}
void dfs2(int u,int flag){
id[u] = ++cnt; hash[cnt] = u;
top[u] = flag ? top[fa[u]] : u;
if (son[u]) dfs2(son[u],true);
Redge(u) if ((to = ed[k].to) != fa[u] && to != son[u])
dfs2(to,false);
}
int mx[maxn << 2],mn[maxn << 2],d1[maxn << 2],d2[maxn << 2],tag[maxn << 2];
void upd(int u){
mx[u] = max(mx[ls],mx[rs]);
mn[u] = min(mn[ls],mn[rs]);
d1[u] = max(max(d1[ls],d1[rs]),mx[rs] - mn[ls]);
d2[u] = max(max(d2[ls],d2[rs]),mx[ls] - mn[rs]);
}
void pd(int u){
if (tag[u]){
mx[ls] += tag[u]; mn[ls] += tag[u]; tag[ls] += tag[u];
mx[rs] += tag[u]; mn[rs] += tag[u]; tag[rs] += tag[u];
tag[u] = 0;
}
}
void build(int u,int l,int r){
if (l == r){
mx[u] = mn[u] = A[hash[l]];
return;
}
int mid = l + r >> 1;
build(ls,l,mid);
build(rs,mid + 1,r);
upd(u);
}
void add(int u,int l,int r,int L,int R,int v){
if (l >= L && r <= R){
mx[u] += v; mn[u] += v; tag[u] += v;
return;
}
pd(u);
int mid = l + r >> 1;
if (mid >= L) add(ls,l,mid,L,R,v);
if (mid < R) add(rs,mid + 1,r,L,R,v);
upd(u);
}
struct node{int mx,mn,d1,d2;};
node query(int u,int l,int r,int L,int R){
if (l >= L && r <= R) return (node){mx[u],mn[u],d1[u],d2[u]};
pd(u);
int mid = l + r >> 1;
if (mid >= R) return query(ls,l,mid,L,R);
if (mid < L) return query(rs,mid + 1,r,L,R);
node t1 = query(ls,l,mid,L,R),t2 = query(rs,mid + 1,r,L,R);
return (node){max(t1.mx,t2.mx),min(t1.mn,t2.mn),max(max(t1.d1,t2.d1),t2.mx - t1.mn),max(max(t1.d2,t2.d2),t1.mx - t2.mn)};
}
node st[2][100],e[100];
int Top[2],tot;
void solve(int u,int v,int w){
int p = 0,ans = 0; Top[0] = Top[1] = 0;
while (top[u] != top[v]){
if (dep[top[u]] < dep[top[v]]){
swap(u,v);
p ^= 1;
}
add(1,1,n,id[top[u]],id[u],w);
st[p][++Top[p]] = query(1,1,n,id[top[u]],id[u]);
u = fa[top[u]];
}
p ^= 1;
if (dep[u] > dep[v]){
swap(u,v);
p ^= 1;
}
add(1,1,n,id[u],id[v],w);
st[p][++Top[p]] = query(1,1,n,id[u],id[v]);
tot = 0;
for (int i = 1; i <= Top[0]; i++){
e[++tot] = st[0][i];
ans = max(ans,e[tot].d2);
}
for (int i = Top[1]; i; i--){
e[++tot] = st[1][i];
ans = max(ans,e[tot].d1);
}
int gmin = INF;
for (int i = 1; i <= tot; i++){
ans = max(ans,e[i].mx - gmin);
gmin = min(gmin,e[i].mn);
}
printf("%d\n",ans);
}
int main(){
n = read();
REP(i,n) A[i] = read();
for (int i = 1; i < n; i++) Build(read(),read());
dfs1(1);
dfs2(1,0);
build(1,1,n);
int m = read(),a,b,w;
while (m--){
a = read(); b = read(); w = read();
solve(a,b,w);
}
return 0;
}