[BJOI2017]树的难题
考虑点分治,然后问题转换成求过根节点的路径的最大值,这种东西一般性就是搞一个线段树然后每一次合并一下两个子树然后查最大值即可
其实我自己也不是很懂,这题确实很神仙
#include<stdio.h>
#include<stdlib.h>
#include<string.h>
#include<math.h>
#include<algorithm>
#include<queue>
#include<set>
#include<map>
#include<iostream>
using namespace std;
#define re register
#define ll long long
inline int gi(){
int f=1,sum=0;char ch=getchar();
while(ch>'9' || ch<'0'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0' && ch<='9'){sum=(sum<<3)+(sum<<1)+ch-'0';ch=getchar();}
return f*sum;
}
const int N=500010,Inf=1e9+10;
int n,m,l,r,f[N],rt,mx,tot,val[N],ans=-Inf;
struct node{
int v,ls,rs;
}t[N*100];
struct edge{int v,c;bool operator<(const edge &b)const{return c<b.c;}};
vector<edge>g[N];
int siz[N],dep[N],used[N],cnt,dis[N];
void getroot(int u,int ff){
siz[u]=1;int now=0;
for(int j=0;j<g[u].size();j++){
edge i=g[u][j];int v=i.v;if(v==ff || used[v])continue;
getroot(v,u);
now=max(now,siz[v]);siz[u]+=siz[v];
}
now=max(now,tot-siz[u]);
if(now<mx){mx=now;rt=u;}
}
void pushup(int o){t[o].v=max(t[t[o].ls].v,t[t[o].rs].v);}
void modify(int &o,int l,int r,int pos,int v){
if(!o){o=++cnt;t[o].v=-Inf,t[o].ls=t[o].rs=0;}
if(l==r){t[o].v=v;return;}
int mid=(l+r)>>1;
if(pos<=mid)modify(t[o].ls,l,mid,pos,v);
else modify(t[o].rs,mid+1,r,pos,v);
pushup(o);
}
int query(int o,int l,int r,int ql,int qr){
if(!o)return -Inf;
if(ql<=l && r<=qr)return t[o].v;
int mid=(l+r)>>1,ret=-Inf;
if(ql<=mid)ret=max(ret,query(t[o].ls,l,mid,ql,qr));
if(qr>mid)ret=max(ret,query(t[o].rs,mid+1,r,ql,qr));
return ret;
}
int merge(int x,int y,int l,int r){
if(!x || !y)return x+y;
if(l==r){t[x].v=max(t[x].v,t[y].v);return x;}
int mid=(l+r)>>1;
t[x].ls=merge(t[x].ls,t[y].ls,l,mid);
t[x].rs=merge(t[x].rs,t[y].rs,mid+1,r);
pushup(x);return x;
}
int hd,tl,q[N],vis[N],pre[N];
void bfs(int s,int C){
vis[s]=1;q[hd=tl=1]=s;pre[s]=C;
while(hd<=tl){
int u=q[hd++];
for(int i=0;i<g[u].size();i++){
int v=g[u][i].v,c=g[u][i].c;
if(vis[v] || used[v])continue;
if(pre[u]!=c)dis[v]=dis[u]+val[c];
else dis[v]=dis[u];
dep[v]=dep[u]+1;pre[v]=c;vis[v]=1;
q[++tl]=v;
}
}
for(int i=1;i<=tl;i++)vis[q[i]]=0;
}
void solve(int u){
cnt=0;
int rt1=0,rt2=0;
for(int i=0;i<g[u].size();i++){
int v=g[u][i].v,c=g[u][i].c;if(used[v])continue;
dep[v]=1;dis[v]=val[c];bfs(v,c);
if(i && g[u][i-1].c!=c)rt1=merge(rt1,rt2,1,r),rt2=0;
for(int j=1;j<=tl;j++){
f[dep[q[j]]]=max(f[dep[q[j]]],dis[q[j]]);
}
ans=max(ans,f[r]);
for(int j=1;j<r && j<=dep[q[tl]] && f[j]!=f[0];j++){
if(l<=j)ans=max(ans,f[j]);
int ret=max(query(rt1,1,r,max(1,l-j),r-j),
query(rt2,1,r,max(1,l-j),r-j)-val[c])+f[j];
ans=max(ans,ret);
}
for(int j=1;j<=dep[q[tl]] && j<r && f[j]!=f[0];j++)
modify(rt2,1,r,j,f[j]);
for(int j=1;j<=tl;j++)f[dep[q[j]]]=-Inf;
}
}
void divide(int u){
used[u]=1;solve(u);
for(int j=0;j<g[u].size();j++){
edge i=g[u][j];int v=i.v;if(used[v])continue;
mx=n;tot=siz[v];
getroot(v,u);
divide(rt);
}
}
int main(){
n=gi();m=gi();l=gi();r=gi();
for(int i=1;i<=m;i++)val[i]=gi();
for(int i=1;i<n;i++){
int u=gi(),v=gi(),c=gi();
g[u].push_back((edge){v,c});
g[v].push_back((edge){u,c});
}
for(int i=1;i<=n;i++)sort(g[i].begin(),g[i].end());
for(int i=0;i<=n;i++)f[i]=-Inf;t[0].v=-Inf;
tot=n;mx=n;
getroot(1,1);
divide(rt);
printf("%d\n",ans);
return 0;
}