记得是JS提高冬令营讲树形dp的压轴题...... 当时看不懂,现在感觉这个 2700* 确实虚高。
分析:
钦定以 \(1\) 为根。
发现这个东西其实说白了就是 \((u,v)\) 路径边必须且只能减一次。然后这条链(指两点路径)上挂着若干颗别的子树,这些子树里的边权被减了两次。然后 \(p=lca(u,v)\) 往上往外还接一个连通块,这里边的边权也是减了两次的。
\(p\) 上面的连通块这部分显然是特殊的。链本身是平凡的,我们仅考虑链上挂子树这个东西。
先设 \(f(u)\) 为从 \(u\) 开始走,走回 \(u\),不离开 \(u\) 子树内的最大 2-Path 权值。显然:
其中 \(w_{u,v}\) 代表 \(u,v\) 两点连边的边权。
然后你发现其实链上最后子树和的贡献答案就是 \(f(u)\),然后减去 \(u\) 若干个儿子 \(v\) 也在路径上,每个 \(v\),都应该减去 \(\max\{f(v)-2\,\times,w_{u,v},0\}\)。
你发现这是个定值,设为 \(h(v)\)。那么链上最后子树和(还加上路径上所有点的点权)的贡献其实就是路径 \(f\) 之和,减去路径 \(h\) 之和,加上 \(h(p)\)。
好的,来看 \(p\) 向上的连通块的贡献,设为 \(g(p)\)。注意这里不要加上 \(a_p\)。
这东西长的就很换根:考虑 \(p\) 的父亲 \(fa\) 转移至 \(p\):
发现也能做。
然后树上静态询问路径和显然可以求 lca 加树上前缀和实现。(就是 \(sum_u\) 表示 \(u\) 向上累加到根的和这种)。
瓶颈在于求 \(lca\),不难做到 \(O((n+q)\log n)\)。
当然你可以用线性 lca 算法(Tarjan)做,因为这题不仅静态还离线.... 就是说理论上可以做到线性。
#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=(a);i<=(b);i++)
#define per(i,a,b) for(int i=(a);i>=(b);i--)
#define op(x) ((x&1)?x+1:x-1)
#define odd(x) (x&1)
#define even(x) (!odd(x))
#define lc(x) (x<<1)
#define rc(x) (lc(x)|1)
#define lowbit(x) (x&-x)
#define Max(a,b) (a>b?a:b)
#define Min(a,b) (a<b?a:b)
#define next Cry_For_theMoon
#define il inline
#define pb(x) push_back(x)
#define is(x) insert(x)
#define sit set<int>::iterator
#define mapit map<int,int>::iterator
#define pi pair<int,int>
#define ppi pair<int,pi>
#define pp pair<pi,pi>
#define fr first
#define se second
#define vit vector<int>::iterator
#define mp(x,y) make_pair(x,y)
typedef long long ll;
typedef unsigned long long ull;
typedef unsigned int uint;
typedef double db;
using namespace std;
const int MAXN=4e5+10;
int n,q,a[MAXN],u[MAXN],v[MAXN];
vector<pi> e[MAXN];
int fa[MAXN],faw[MAXN];
ll f[MAXN],g[MAXN],h[MAXN];
int top[MAXN][20],power[20],depth[MAXN];
ll sumf[MAXN],sumh[MAXN],sumd[MAXN];
void dfs1(int u){
for(auto V:e[u]){
int v=V.fr,w=V.se;
if(v==fa[u])continue;
sumd[v]=sumd[u]+w;
fa[v]=u;faw[v]=w;
dfs1(v);
}
}
void dfs2(int u){
f[u]=a[u];
for(auto V:e[u]){
int v=V.fr,w=V.se;
if(v==fa[u])continue;
dfs2(v);
f[u]+=max(f[v]-w*2,0LL);
}
}
void dfs3(int u){
for(auto V:e[u]){
int v=V.fr,w=V.se;
if(v==fa[u])continue;
//换根
g[v]=max(g[u]+f[u]-w*2-max(f[v]-w*2,0LL),0LL);
dfs3(v);
}
}
void dfs4(int u){
depth[u]=depth[fa[u]]+1;
sumf[u]=sumf[fa[u]]+f[u];
sumh[u]=sumh[fa[u]]+h[u];
top[u][0]=fa[u];
rep(j,1,19)top[u][j]=top[top[u][j-1]][j-1];
for(auto V:e[u]){
int v=V.fr;
if(v==fa[u])continue;
depth[v]=depth[u]+1;
dfs4(v);
}
}
int lca(int u,int v){
if(depth[u]<depth[v])swap(u,v);
per(j,19,0){
if(depth[u]-power[j]>=depth[v])u=top[u][j];
}
if(u==v)return u;
per(j,19,0){
if(top[u][j]!=top[v][j]){
u=top[u][j];v=top[v][j];
}
}
return top[u][0];
}
void solve(int u,int v){
int p=lca(u,v);ll ret=0;
ret+=sumf[u]+sumf[v]-sumf[fa[p]]*2-f[p];
ret-=(sumh[u]+sumh[v]-sumh[fa[p]]*2-h[p]*2);
ret+=g[p];
ret-=(sumd[u]+sumd[v]-sumd[p]*2);
printf("%lld\n",ret);
}
int main(){
power[0]=1;rep(i,1,19)power[i]=power[i-1]*2;
scanf("%d%d",&n,&q);
rep(i,1,n)scanf("%d",&a[i]);
rep(i,1,n-1){
int u,v,w;scanf("%d%d%d",&u,&v,&w);
e[u].pb(mp(v,w));e[v].pb(mp(u,w));
}
rep(i,1,q){
scanf("%d%d",&u[i],&v[i]);
}
dfs1(1);
dfs2(1);
dfs3(1);
rep(i,1,n)h[i]=max(0LL,f[i]-2*faw[i]);
dfs4(1);
rep(i,1,q){
solve(u[i],v[i]);
}
return 0;
}