关于树上路径类问题,大概率是点分治。
考虑每次计算 \(u\) 的子树中所有 LCA 是 \(u\) 的点对 \((x,y)\) 所构成的路径的贡献最大值。(我们先不考虑一些也许需要特判之类的特殊情况,比如 \(x\) 是 \(u\))
容易发现这条路径认为两段:\(x\Rightarrow u\) 以及 \(y\Rightarrow u\) 两部分。显然我们可以一遍 dfs 求出 \(u\) 到每个子树内节点的路径的贡献。然后考虑合并路径。
显然:设 \(x\) 在 \(u\) 的儿子 \(v\) 的子树内,\(y\) 在 \(u\) 的另一儿子 \(v'\) 的子树内,当 \(u\Rightarrow v\) 这条边的颜色 \(c\) 等于 \(u\Rightarrow v'\) 这条边的颜色 \(c'\) 时答案是两条路径贡献之和减去一个颜色 \(c=c'\) 的贡献;否则答案就是两条路径相加。显然第二种比较无脑先去考虑它。注意到第二种情况下,这两个点一定不在同一颗 \(x\) 的儿子所形成的子树内。
为了下文方便起见,我们称 \(x\) 所在的子树,就是指那个是 \(x\) 的祖先的,同时也是 \(u\) 的儿子的节点 \(v\) 为根所形成的子树(有点绕,多看几遍就理解了),把 \(v\) 记作 \(root_x\)。称 \(x\) 的子树的颜色就是指 \(u\Rightarrow root_x\) 这条边的颜色,记作 \(color_x\).
当我们把所有儿子按照 \(color_x\) 划分成若干类以后,我们枚举每一个节点 \(x\) 所贡献的路径 \(u\Rightarrow x\),然后考虑寻找另一条路径。开一个桶 \(bucket(i)\) 维护可选的所有节点 \(y\) 中满足 \(u\Rightarrow y\) 的长度为 \(i\) 的所有路径中最大的那个贡献。一个节点可选就是说该节点所在的所有 \(color\) 与它相同的点都计算完了,此时把这些点更新桶的数值。寻找另一条路径的时候,设当前路径长度 \(d\),显然答案在 \([\max\{0,l-d\},min\{r-d,maxd\}\),其中 \(maxd\) 代表已经加入过桶的路径的最长深度,这个变量每次加入桶的时候维护就好了。注意答案是取这个区间的最大值。所以这个部分是一个“有多个取值区间,求每个取值区间最大值”的过程,显然当 \(d\) 升序的时候区间两端点全部单调非严格递减,因此显然可以用单调队列维护。那么计算第一类情况同理,只不过要按照 \(color_x\) 分过以后再按照所在子树分类,统计到一个新颜色,清空桶,统计到一个新子树,将之前那颗新子树的答案加入桶。
但是注意:单调队列的复杂度,总体是 \(O(maxd)\) 的,看上去没问题,但以第二类计算为例,每次碰到一个不同的 \(color_x\) 都要花费 \(O(maxd)\) 的时间,那么每次 \(solve\) 复杂度其实与 \(O(n^2)\) 同级(虽然没有这么极端但绝对会TLE没得说)。不过注意实际上 \(maxd\) 之和是与 \(n\) 同级的,坏就坏在如果你一开始统计了一颗 \(maxd\) 大的子树,后面每次单调队列的 \(maxd\) 都大了,换言之,我们要优先使用小的 \(maxd\)。所以计算第二部分的时候,我们按照每颗子树的 \(maxd\) 排序。但实际上第二部分是优先按照 \(color_x\) 划分的,所以我们还要计算每个颜色的子树的 \(maxd\) 的最大值,用这个对 \(color\) 排序。
这样整个复杂度就与 \(O(n)\) 同级(或者是 \(O(n log n)\),瓶颈在sort所以不很重要)
那么因为排序了,点分治的复杂度就是 \(O(n log^3n)\) 级别的,既然时限 \(2\) 秒那就是随便过了(
//BJOI,2017
#include<bits/stdc++.h>
#define rep(i,a,b) for(ll i=(a);i<=(b);i++)
#define per(i,a,b) for(ll i=(a);i>=(b);i--)
#define next Cry_For_theMoon
#define type Cry_for_theMoon
typedef long long ll;
using namespace std;
const int MAXN=2e5+10,MAXM=4e5+10,INF=2e9+10;
struct Edge{
int u,v,w;
}edge[MAXM];
int first[MAXN],next[MAXM],tot;
int n,m,l,r,c[MAXN],u,v,w,ans=-INF;
int vis[MAXN],sz[MAXN],rt;
int bucket[MAXN],kind[MAXN],depth[MAXN],sum[MAXN],arr[MAXN],cnt;
int type[MAXN],maxdepth[MAXN],typedepth[MAXN];
int q[MAXN];
void addedge(int,int,int);
void dfs1(int,int,int);
void dfs2(int,int);
void divide(int,int);
void solve(int);
void calc(int,int,int,int);
bool cmp1(const int,const int);
bool cmp2(const int,const int);
int main(){
scanf("%d%d%d%d",&n,&m,&l,&r);
rep(i,1,m)scanf("%d",&c[i]);
rep(i,1,n-1){
scanf("%d%d%d",&u,&v,&w);
addedge(u,v,w);addedge(v,u,w);
}
divide(1,n);
printf("%d",ans);
return 0;
}
void addedge(int u,int v,int w){
edge[++tot].u=u;edge[tot].v=v;edge[tot].w=w;
next[tot]=first[u];first[u]=tot;
}
void dfs1(int u,int fa,int all){
sz[u]=1;int flag=1;
for(int j=first[u];j;j=next[j]){
int v=edge[j].v;
if(vis[v] || v==fa)continue;
dfs1(v,u,all);
sz[u]+=sz[v];
if(sz[v]>all/2)flag=0;
}
if(flag && (all-sz[u])<=all/2)rt=u;
}
void dfs2(int u,int fa){
sz[u]=1;
for(int j=first[u];j;j=next[j]){
int v=edge[j].v;
if(vis[v] || v==fa)continue;
dfs2(v,u);
sz[u]+=sz[v];
}
}
void divide(int u,int all){
dfs1(u,0,all);
dfs2(rt,0);
solve(rt);
vis[rt]=1;
for(int j=first[rt];j;j=next[j]){
int v=edge[j].v;
if(vis[v])continue;
divide(v,sz[v]);
}
}
void solve(int u){
cnt=0;depth[u]=0;
for(int j=first[u];j;j=next[j]){
int v=edge[j].v;
if(vis[v])continue;
kind[v]=v;
depth[v]=1;
sum[v]=c[edge[j].w];
maxdepth[v]=1;
typedepth[edge[j].w]=0;
type[v]=edge[j].w;
calc(v,u,v,edge[j].w);
}
rep(i,1,cnt){
typedepth[type[arr[i]]]=max(typedepth[type[arr[i]]],depth[arr[i]]);
}
bucket[0]=-INF;
sort(arr+1,arr+1+cnt,cmp2);
int pre=1;
int head=1,rear=1,preL=0,maxd=0;
rep(i,1,cnt){
if(depth[arr[i]]>=l && depth[arr[i]]<=r){ans=max(ans,sum[arr[i]]);}
if(type[arr[i]]!=type[arr[i-1]]){
head=rear=1;
rep(j,pre,i-1){
bucket[depth[arr[j]]]=max(bucket[depth[arr[j]]],sum[arr[j]]);
maxd=max(maxd,depth[arr[j]]);
}
int L=max(0,l-depth[arr[i]]),R=min(r-depth[arr[i]],maxd);
per(j,R,L){
if(bucket[j]!=-INF){
while(head<rear && bucket[j]>=bucket[q[rear-1]])rear--;
q[rear++]=j;
}
}
if(head<rear){
ans=max(ans,sum[arr[i]]+bucket[q[head]]);
}
pre=i;preL=L;
continue;
}
int L=max(0,l-depth[arr[i]]),R=min(r-depth[arr[i]],maxd);
while(head<rear && q[head]>R)head++;
per(j,preL-1,L){
if(bucket[j]!=-INF){
while(head<rear && bucket[j]>=bucket[q[rear-1]])rear--;
q[rear++]=j;
}
}
if(head<rear){
ans=max(ans,sum[arr[i]]+bucket[q[head]]);
}
preL=L;
}
sort(arr+1,arr+1+cnt,cmp1);
bucket[0]=bucket[1]=-INF;
rep(i,1,cnt){
bucket[depth[arr[i]]]=-INF;
}
int presection=1;
head=1,rear=1,maxd=0,preL=0,pre=1;
rep(i,1,cnt){
if(type[kind[arr[i]]]!=type[kind[arr[i-1]]]){
bucket[0]=bucket[1]=-INF;
rep(j,presection,i-1){
bucket[depth[arr[j]]]=-INF;
}
head=1,rear=1,maxd=0,preL=0;
pre=i;
presection=i;
continue;
}
if(kind[arr[i]]!=kind[arr[i-1]]){
head=1,rear=1;
rep(j,pre,i-1){
bucket[depth[arr[j]]]=max(bucket[depth[arr[j]]],sum[arr[j]]);
maxd=max(maxd,depth[arr[j]]);
}
int L=max(0,l-depth[arr[i]]),R=min(r-depth[arr[i]],maxd);
per(j,R,L){
if(bucket[j]!=-INF){
while(head<rear && bucket[j]>=bucket[q[rear-1]])rear--;
q[rear++]=j;
}
}
if(head<rear){
ans=max(ans,sum[arr[i]]+bucket[q[head]]-c[type[arr[i]]]);
}
preL=L,pre=i;
continue;
}
int L=max(0,l-depth[arr[i]]),R=min(r-depth[arr[i]],maxd);
while(head<rear && q[head]>R)head++;
per(j,preL-1,L){
if(bucket[j]!=-INF){
while(head<rear && bucket[j]>=bucket[q[rear-1]])rear--;
q[rear++]=j;
}
}
if(head<rear){
ans=max(ans,sum[arr[i]]+bucket[q[head]]-c[type[arr[i]]]);
}
preL=L;
}
}
void calc(int u,int fa,int root,int ed){
arr[++cnt]=u;
bucket[depth[u]]=-INF;
for(int j=first[u];j;j=next[j]){
int v=edge[j].v;
if(vis[v] || fa==v)continue;
kind[v]=kind[u];
depth[v]=depth[u]+1;
sum[v]=sum[u];
type[v]=type[u];
if(ed!=edge[j].w)sum[v]+=c[edge[j].w];
maxdepth[root]=max(maxdepth[root],depth[v]);
calc(v,u,root,edge[j].w);
}
}
bool cmp1(const int a,const int b){
if(type[a]!=type[b]){
if(typedepth[type[a]]!=typedepth[type[b]])return typedepth[type[a]]<typedepth[type[b]];
return type[a]<type[b];
}
if(kind[a]!=kind[b]){
if(maxdepth[kind[a]]!=maxdepth[kind[b]])return maxdepth[kind[a]]<maxdepth[kind[b]];
return kind[a]<kind[b];
}
return depth[a]<depth[b];
}
bool cmp2(const int a,const int b){
if(type[a]!=type[b]){
if(typedepth[type[a]]!=typedepth[type[b]])return typedepth[type[a]]<typedepth[type[b]];
return type[a]<type[b];
}
return depth[a]<depth[b];
}