题解 影子
这篇题解拖得有点久了……
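难点在于「路径上所有经过的点取最小点权」这一条件。
暴力做法：直接枚举最小点权在哪个点 u 取到，然后只在点权不小于 val[u] 的点中，求经过 u 的最长路径（即从 u 出发的最长、次长两条分支之和），复杂度 O(n^2)。
于是考虑优化：把点按点权从大到小依次加入，用并查集维护形成的连通块，并对每个连通块维护其直径的两个端点；每加入一个点 u，就用 val[u] 乘以 u 所在连通块的直径长度来更新答案。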
- 类似这道题，要求树上路径中权值最小（或最大）的点/边参与计算时，可以考虑将点权/边权排序后维护一个集合，按顺序向集合中加点/边：求最小值时按权值从大到小加入，求最大值时按权值从小到大加入。这样后加入的点/边的权值一定是集合中当前最小（或最大）的，方便计算。
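而且有个结论：（边权非负时）用一条边将两棵树连通，所得新树直径的两个端点，一定在原来两棵树各自直径的 4 个端点之中，所以只需比较这 4 个点两两之间的 6 种组合。
证明（简要思路）：新树的直径要么完全落在原来的某一棵树中，此时它就是那棵树的直径；要么经过新连的边 (x, y)，此时它等于「T1 中离 x 最远的距离 + 边权 + T2 中离 y 最远的距离」。而在边权非负的树中，离任意一点最远的距离总能在任意一条直径的某个端点处取到，因此两种情况下新直径的端点都落在原来的 4 个直径端点中。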
注意：memset 二维数组时要老老实实用 sizeof（整个数组），尤其是 LCA 里的 fa[][]——好几次都把它当成一维数组去 memset，结果只清空了一部分。
Code:
#include <bits/stdc++.h>
using namespace std;
// Shorthand macros. INF, ld, usd and ull are not used anywhere in this file.
#define INF 0x3f3f3f3f
// Capacity of every per-vertex array. task::solve clears n+10 slots, so n is
// presumably at most N - 10 -- TODO confirm against the problem limits.
#define N 100010
#define ll long long
#define ld long double
#define usd unsigned
#define ull unsigned long long
// Uncomment to turn every `int` into `long long`; main is declared
// `signed main` so that it still compiles in that case.
//#define int long long
// Fast input: getchar() is replaced by a reader that refills the 1<<21-byte
// buffer `buf` from stdin with fread whenever [p1, p2) is empty, and yields
// EOF once fread returns nothing more. read() expands this macro, so it must
// stay defined before read().
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
char buf[1<<21], *p1=buf, *p2=buf;
// Reads one decimal integer from the input via getchar().
// Non-digit characters before the number are skipped, and each '-' among them
// flips the sign. Returns 0 if the input ends before any digit is found.
// The character is held in an int (not char) so EOF stays distinguishable.
// Digits are compared explicitly because isdigit() on a negative char is UB.
inline long long read() {
    long long ans = 0;
    int f = 1;
    int c = getchar();
    while (c < '0' || c > '9') {
        if (c == EOF) return 0;  // without this the skip loop would spin forever
        if (c == '-') f = -f;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        ans = ans * 10 + (c - '0');
        c = getchar();
    }
    return ans * f;
}
// Current test case: n = number of vertices. m is not used anywhere in this file.
int n, m;
// Forward-star adjacency lists: head[u] is the id of u's most recently added
// edge (0 = no edge), and `size` counts the edges added so far (ids start at 1).
// NOTE(review): with `using namespace std;` under C++17, an unqualified `size`
// is ambiguous with std::size; refer to this counter as `::size`.
int head[N], size;
// Vertex weights, indexed 1..n.
ll val[N];
// One directed edge: target vertex, id of the next edge in the same list, and
// weight. Each undirected tree edge is stored twice, hence N<<1 slots.
struct edge{int to, next; ll val;}e[N<<1];
inline void add(int s, int t, ll w) {edge* k=&e[++size]; k->to=t; k->val=w; k->next=head[s]; head[s]=size;}
// Sort record for a vertex: rk is the vertex id and val its weight.
struct point{
    int rk;
    ll val;
    // Fills in both fields.
    inline void build(int i, ll v_) {
        val = v_;
        rk = i;
    }
}p[N];
// Ordering used by std::sort: a point comes first when its weight is larger,
// so sorting puts the heaviest vertex first.
inline bool operator < (point a, point b) {return b.val < a.val;}
namespace force{
ll dfs(int u, int fa, ll minn) {
ll t=0;
for (int i=head[u],v; i; i=e[i].next) {
v = e[i].to;
if (v!=fa && val[v]>=minn) {
t = max(t, dfs(v, u, minn)+e[i].val);
}
}
return t;
}
ll ask(int u, ll minn) {
ll t1=0, t2=0, t;
for (int i=head[u],v; i; i=e[i].next) {
v = e[i].to;
if (val[v]>=minn) {
t = dfs(v, u, minn)+e[i].val;
if (t>=t1) {
t2=t1; t1=t;
}
else if (t>t2) t2=t;
}
}
return t1+t2;
}
void solve() {
ll ans=0;
for (int i=1; i<=n; ++i) ans=max(ans, val[i]*ask(i, val[i]));
printf("%lld\n", ans);
}
}
namespace task{
ll ans, len[N], cse[10];
int pa[N], dep[N], fa[22][N], s1[N], s2[N], lg[N];
inline int find(int p) {return pa[p]==p?p:(pa[p]=find(pa[p]));}
void dfs(int u, int pa) {
for (int i=1; i<=20; ++i)
if (dep[u]>=(1<<i)) fa[i][u] = fa[i-1][fa[i-1][u]];
else break;
for (int i=head[u],v; i; i=e[i].next) {
v = e[i].to;
if (v!=pa) dep[v]=dep[u]+1, fa[0][v]=u, len[v]=len[u]+e[i].val, dfs(v, u);
}
}
int lca(int a, int b) {
if (dep[a]<dep[b]) swap(a, b);
while (dep[a]>dep[b]) a=fa[lg[dep[a]-dep[b]]-1][a];
if (a==b) return a;
for (int i=lg[dep[a]]-1; i>=0; --i)
if (fa[i][a]!=fa[i][b])
a=fa[i][a], b=fa[i][b];
return fa[0][a];
}
//inline int dis(int a, int b) {return len[a]+len[b]-2*len[lca(a, b)];}
#define dis(a, b) (len[a]+len[b]-len[lca(a, b)]*2)
ll init(int sa1, int sa2, int sb1, int sb2) {
ll maxn=0;
maxn=max(maxn, cse[1]=dis(sa1, sa2));
maxn=max(maxn, cse[2]=dis(sb1, sb2));
maxn=max(maxn, cse[3]=dis(sa1, sb1));
maxn=max(maxn, cse[4]=dis(sa1, sb2));
maxn=max(maxn, cse[5]=dis(sa2, sb1));
maxn=max(maxn, cse[6]=dis(sa2, sb2));
return maxn;
}
void solve() {
ans=0;
memset(dep, 0, sizeof(int)*(n+10));
memset(fa, 0, sizeof(fa));
memset(len, 0, sizeof(ll)*(n+10));
memset(s1, 0, sizeof(int)*(n+10));
memset(s2, 0, sizeof(int)*(n+10));
for (int i=1; i<=n; ++i) lg[i]=lg[i-1]+(1<<lg[i-1]==i);
for (int i=1; i<=n; ++i) pa[i]=s1[i]=s2[i]=i;
dep[1]=1;
dfs(1, 0);
sort(p+1, p+n+1);
ll maxn;
for (int i=1,u,fu,f1; i<=n; ++i) {
//cout<<p[i].rk<<endl;
//cout<<p[i].val<<endl;
u=p[i].rk;
for (int j=head[u],v; j; j=e[j].next) {
v=e[j].to;
if (val[v]<val[u]) continue;
//cout<<v<<endl;
//cout<<u<<' '<<v<<endl;
//for (int i=1; i<=n; ++i) cout<<pa[i]<<' '; cout<<endl;
fu=find(u); f1=find(v);
//cout<<"fa: "<<fu<<' '<<f1<<endl;
if (fu==f1) continue;
pa[f1]=fu;
//cout<<"pos3"<<endl;
maxn=init(s1[fu], s2[fu], s1[f1], s2[f1]);
//for (int k=1; k<=6; ++k) cout<<cse[k]<<' '; cout<<endl;
for (int k=1; k<=6; ++k) if (cse[k]==maxn) {
//cout<<"k: "<<k<<endl;
switch (k) {
case 1: break;
case 2: s1[fu]=s1[f1]; s2[fu]=s2[f1]; break;
case 3: s2[fu]=s1[f1]; break;
case 4: s2[fu]=s2[f1]; break;
case 5: s1[fu]=s1[f1]; break;
case 6: s1[fu]=s2[f1]; break;
}
break;
}
}
fu=find(u);
//cout<<"upd: "<<p[i].val<<' '<<s1[fu]<<' '<<s2[fu]<<' '<<dis(s1[fu], s2[fu])<<endl;
if (s1[fu]&&s2[fu]) ans=max(ans, 1ll*p[i].val*dis(s1[fu], s2[fu]));
}
printf("%lld\n", ans);
//exit(0);
}
}
signed main()
{
#ifdef DEBUG
freopen("1.in", "r", stdin);
#endif
int T;
ll w;
T=read();
while (T--) {
memset(head, 0, sizeof(head));
size=0;
n=read();
for (int i=1; i<=n; ++i) p[i].build(i, val[i]=read());
for (int i=1,u,v; i<n; ++i) {
u=read(); v=read(); w=read();
add(u, v, w); add(v, u, w);
}
//force::solve();
task::solve();
}
return 0;
}