题解:「COCI2019」 Transport
树上点对问题考虑点分治。
按照套路讨论经过一个点 \(p\) 的路径的情况。
对于这道题的每一个点对 \((u,v)\) 的路径可以分为 \(u\rightarrow p\) 和 \(p\rightarrow v\) 两段。
对于前者,需要知道走完这条路径后剩余的油量,记作 \(L_{u\rightarrow p}\) ,
对于后者,需要知道在允许油量小于 0 时,走这条路径时油量的最小值,记作 \(R_{p\rightarrow v}\) 。
这样只要满足 \(L_{u\rightarrow p} + R_{p\rightarrow v} \geq 0\) 时,点对 \((u,v)\) 成立。
处理出 \(L\) 和 \(R\) 后,将两者暴力排序,之后用双指针扫一遍就行了,不要忘了用容斥或者染色法避免重复计数。
时间复杂度 \(\mathcal{O}(n\log_2^2 n)\)
Code(C++):
#include<bits/stdc++.h>
#define forn(i,s,t) for(register int i=(s);i<=(t);++i)
using namespace std;
typedef long long LL;
const int N = 1e5+1e3;
template<typename T>inline T Max(T A, T B) {return A>B?A:B;}
template<typename T>inline T Min(T A, T B) {return A<B?A:B;}
struct List {int dir, nxt; LL lng; } E[N<<1];
int G[N], cnt;
inline void Add(int u, int v, LL w) {
E[++cnt].dir = v, E[cnt].nxt = G[u], G[u] = cnt;
E[cnt].lng = w;
}
int n; LL a[N], Ans;
bool vis[N];
int Mx[N], sz[N], rt, S;
void Getrt(int u, int fa) { // 找重心
sz[u] = 1, Mx[u] = 0;
for(int i=G[u];i;i=E[i].nxt) {
int v = E[i].dir;
if(vis[v] || v == fa) continue ;
Getrt(v, u);
sz[u] += sz[v];
Mx[u] = Max(Mx[u], sz[v]);
}
Mx[u] = Max(Mx[u], S - sz[u]);
(Mx[u] < Mx[rt]) && (rt = u);
}
LL L[N], R[N];
int f[N], l, r, Cl[N], Cr[N], now, Nu, idl[N], idr[N];
inline bool cmpl(int A, int B) {return L[A] < L[B];}
inline bool cmpr(int A, int B) {return R[A] < R[B];}
void dfs(int u, int fa, LL dist, LL Dn, LL Up) {
if(Up == 0) L[++l] = dist, Cl[l] = now;
R[++r] = -Dn, Cr[r] = now;
for(int i=G[u];i;i=E[i].nxt) {
int v = E[i].dir; LL w = E[i].lng;
if(vis[v] || v == fa) continue ;
dfs(v, u, dist - w + a[v], Min(Dn, dist - a[Nu] - w), Min(0ll, Up - w + a[v]));
}
}
void calc(int u) {
l = r = 0, Nu = u;
L[++l] = a[u], Cl[l] = u, f[u] = 0;
R[++r] = 0, Cr[r] = u;
for(int i=G[u];i;i=E[i].nxt) {
int v = E[i].dir; LL w = E[i].lng;
if(vis[v]) continue ;
now = v; f[now] = 0;
dfs(v, u, a[u] + a[v] - w, -w, Min(0ll, - w + a[v]));
}
forn(i,1,l) idl[i] = i; sort(idl+1, idl+l+1, cmpl);
forn(i,1,r) idr[i] = i; sort(idr+1, idr+r+1, cmpr);
int j = 1;
forn(i,1,l) { // 2 Pnt
while(j <= r && R[idr[j]] <= L[idl[i]]) f[Cr[idr[j++]]] ++ ;
Ans += j - 1 - f[Cl[idl[i]]];
}
}
void solve(int u) {
vis[u] = 1;
calc(u);
for(int i=G[u];i;i=E[i].nxt) {
int v = E[i].dir;
if(vis[v]) continue ;
S = sz[v] > sz[u] ? n - sz[u] : sz[v];
rt = 0, Getrt(v, v);
solve(rt);
}
}
int main() {
scanf("%d", &n);
forn(i,1,n) scanf("%lld", a+i);
forn(i,2,n) {
int u, v; LL w;
scanf("%d%d%lld", &u, &v, &w);
Add(u, v, w), Add(v, u, w);
}
S = Mx[rt = 0] = n;
Getrt(1, 1);
solve(rt);
printf("%lld\n", Ans);
return 0;
}