[bzoj4860] [loj#2179] [BJOI2017] 树的难题

题意简述

\(n\) 个点的无根树。
树上每条边有颜色。共 \(m\) 种颜色,第 \(i\) 种颜色有权值 \(c_i\)
对于树上一条简单路径,路径上经过的所有边按顺序组成一个颜色序列,序列可以划分成若干个相同颜色段。定义路径权值为颜色序列上每个同颜色段的颜色权值之和。
求长度在 \([l,r]\) 的路径权值最大值。

想法

容易想到点分治。然后大体框架就有了。
问题是如何统计长度在 \([l,r]\) 的路径的权值。
容易想到在访问子节点时,同一个颜色的挨着访问,将访问完的路径分为“不同颜色”与“相同颜色”分别按长度排序,然后单调队列。
然后TLE了几个点……

问题在于,如果我们访问的第一个颜色的最长路径很长的话,访问后面点时每个点都进行一次单调队列,相当于这个“长路径”访问了很多很多遍。
于是改变一下访问顺序。
仍是同一个颜色的挨着访问,维护两个单调队列。
将颜色按照此颜色的最大深度从小到大排序,将同一颜色中的点按该点可达的最大深度从小到大排序。
每个点访问完后,与该颜色的单调队列跑一次;一个颜色中的点都访问后,将此颜色的单调队列合并至总的单调队列中。
由于有排序,总复杂度 \(O(nlog^2n)\)

还有一种线段树的【暴力】做法:
仍是同一个颜色的挨着访问。维护两个线段树,一个表示相同颜色,一个表示不同颜色。
访问完一个点后,将访问到的各长度的路径在线段树中找区间最大值更新答案。
访问完一个颜色后,两个线段树合并。合并复杂度 \(O(nlogn)\)
总复杂度 \(O(nlog^2n)\),但是常数较大。

总结

我就是觉得复杂度这个东西太玄妙了!!爱了爱了!
顺序很重要!!

代码

写+调的我要裂开了。。。

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>

#define INF 1000000000

using namespace std;

int read(){
	int x=0,f=1;
	char ch=getchar();
	while(!isdigit(ch) && ch!='-') ch=getchar();
	if(ch=='-') f=-1,ch=getchar();
	while(isdigit(ch)) x=x*10+ch-'0',ch=getchar();
	return x*f;
}

const int N = 200005;
typedef long long ll;
typedef pair<int,int> Pr;

int n,m,L,R,cv[N];

struct edge{
	int u,v,c;
	bool operator < (const edge &b) const{ return c<b.c; }
}ed[N];
struct node{
	int v,c;
	node *nxt;
}pool[N*2],*h[N];
int cnt;
void addedge(int u,int v,int c){
	node *p=&pool[++cnt],*q=&pool[++cnt];
	p->v=v;p->nxt=h[u];h[u]=p;p->c=c;
	q->v=u;q->nxt=h[v];h[v]=q;q->c=c;
}

int root,all,sz[N],mx[N],vis[N];
void getrt(int u,int fa){
	int v;
	sz[u]=1; mx[u]=0;
	for(node *p=h[u];p;p=p->nxt)
		if((v=p->v)!=fa && !vis[v]){
			getrt(v,u);
			sz[u]+=sz[v];
			mx[u]=max(mx[u],sz[v]);
		}
	mx[u]=max(mx[u],all-sz[u]);
	if(mx[u]<mx[root]) root=u;
}
int dis[N];
void dfs_sz(int u,int fa){
	int v;
	sz[u]=1;
	dis[u]=0;
	for(node *p=h[u];p;p=p->nxt)
		if((v=p->v)!=fa && !vis[v]){
			dfs_sz(v,u);
			dis[u]=max(dis[u],dis[v]);
			sz[u]+=sz[v];
		}
	dis[u]++;
}
ll b[N],pre[N],cur[N];
void cal(int u,int fa,int c,ll sum,int len){
	cur[len]=max(sum+cv[c],cur[len]);
	if(len>=R) return;
	int v;
	for(node *p=h[u];p;p=p->nxt)
		if((v=p->v)!=fa && !vis[v]){
			if(p->c==c) cal(v,u,c,sum,len+1);
			else cal(v,u,p->c,sum+cv[c],len+1);
		}
}
ll ans;
int que[N],hd,tl,son[N],col[N],mxd[N];
bool cmp(int x,int y){
	if(col[x]==col[y]) return dis[x]<dis[y]; /**/
	if(mxd[col[x]]==mxd[col[y]]) return col[x]<col[y];
	return mxd[col[x]]<mxd[col[y]];
}

ll mxb[N];
void work(int u){
	int v,lastc=0,dep,tb=0,tp=0;
	vis[u]=1;
	
	int sn=0;
	for(node *p=h[u];p;p=p->nxt) 
		if(!vis[v=p->v]){
			son[sn++]=v;
			col[v]=p->c;
			dfs_sz(v,u);
			dis[v]=min(dis[v],R);
			mxd[col[v]]=max(mxd[col[v]],dis[v]);
		} 
	sort(son,son+sn,cmp);
	for(int i=0;i<sn;i++) mxd[col[son[i]]]=0;
	
	for(int k=0;k<sn;k++){
		v=son[k];
		cal(v,u,col[v],0,1);
		for(int i=L;i<=dis[v];i++) if(i<=R) ans=max(ans,cur[i]);
		if(col[v]!=lastc){
			hd=tl=0;
			for(int i=tp,j=1;i>0;i--){
				while(j<=tb && j+i<=R){
					while(hd<tl && b[que[tl-1]]<=b[j]) tl--;
					que[tl++]=j++;
				}
				while(hd<tl && que[hd]+i<L) hd++;
				if(hd<tl) ans=max(ans,b[que[hd]]+pre[i]);
			}
			tb=tp;
			for(int i=1;i<=tp;i++) b[i]=max(b[i],pre[i]),pre[i]=-INF;/**/
			tp=dis[v];
			for(int i=1;i<=dis[v];i++) pre[i]=cur[i];
		}
		else{
			hd=tl=0;
			for(int i=dis[v],j=1;i>0;i--){
				while(j<=tp && j+i<=R){
					while(hd<tl && pre[que[tl-1]]<=pre[j]) tl--;
					que[tl++]=j++;
				}
				while(hd<tl && que[hd]+i<L) hd++;
				if(hd<tl) ans=max(ans,pre[que[hd]]+cur[i]-cv[col[v]]);
			} 
			for(int i=1;i<=dis[v];i++) pre[i]=max(pre[i],cur[i]);
			tp=dis[v];
		}
		lastc=col[v];
		for(int i=1;i<=dis[v];i++) cur[i]=-INF;
	}
	hd=tl=0;
	for(int i=tp,j=1;i>0;i--){
		while(j<=tb && j+i<=R){
			while(hd<tl && b[que[tl-1]]<=b[j]) tl--;
			que[tl++]=j++;
		}
		while(hd<tl && que[hd]+i<L) hd++;
		if(hd<tl) ans=max(ans,b[que[hd]]+pre[i]);
	}
	//clear
	for(int i=0;i<=tb;i++) b[i]=-INF;
	for(int i=0;i<=tp;i++) pre[i]=-INF;
	
	for(node *p=h[u];p;p=p->nxt)
		if(!vis[v=p->v]){
			root=0; all=sz[v]; getrt(v,u);
			work(root);
		}
}

int main()
{
	n=read(); m=read(); L=read(); R=read();
	for(int i=1;i<=m;i++) cv[i]=read();
	for(int i=1;i<n;i++){ ed[i].u=read(); ed[i].v=read(); ed[i].c=read(); }
	sort(ed+1,ed+n);
	for(int i=1;i<n;i++) addedge(ed[i].u,ed[i].v,ed[i].c);
	
	sz[0]=mx[0]=n+1;
	root=0; all=n; getrt(1,0);
	
	for(int i=0;i<=n;i++) b[i]=-INF,pre[i]=-INF,cur[i]=-INF;
	ans=-INF;work(root);
	
	printf("%lld\n",ans);
	
	return 0;
}
posted @ 2020-02-20 19:47  秋千旁的蜂蝶~  阅读(162)  评论(0编辑  收藏  举报