[BZOJ2870]最长道路 Tree 多解
[BZOJ2870]最长道路 Tree 多解
题面
给定一棵N个点的树,求树上一条链使得链的长度乘链上所有点中的最小权值所得的积最大。
其中链长度定义为链上点的个数。
分析
解法一
考虑点分治。对于每个分治中心,\(getdis()\) 把每个子树的点到分治中心的距离和路径中点的最小权值存入数组,然后暴力和前面子树进行合并。
这样的复杂度是显然不行的。(万一数据水呢,期望得分 70pts)
#include<bits/stdc++.h>
#define int long long
#define pii pair<int,int>
#define mp make_pair
#define fi first
#define se second
using namespace std;
const int N = 5e4+5;
void init(){
freopen("a.in","r",stdin);
freopen("a.out","w",stdout);
}
int read(){
int s=0,w=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
while(ch>='0'&&ch<='9')s=s*10+ch-'0',ch=getchar();
return s*w;
}
struct Edge{
int nex,to;
}edge[N<<1];
int head[N],siz[N],mxp[N],dis[N],mn[N],val[N];
int n,elen,root,ans,cnt1,cnt2,total;
pii s1[N],s2[N];
bool vis[N];
void addedge(int u,int v){
edge[++elen]={head[u],v};head[u]=elen;
edge[++elen]={head[v],u};head[v]=elen;
}
void getroot(int u,int fath){
siz[u]=1;mxp[u]=0;
for(int e=head[u];e;e=edge[e].nex){
int v=edge[e].to;
if(v==fath||vis[v])continue;
getroot(v,u);siz[u]+=siz[v];
mxp[u]=max(mxp[u],siz[v]);
}
mxp[u]=max(mxp[u],total-siz[u]);
if(mxp[u]<mxp[root])root=u;
}
void getdis(int u,int fath){
s2[++cnt2]=mp(dis[u],mn[u]);
for(int e=head[u];e;e=edge[e].nex){
int v=edge[e].to;
if(v==fath||vis[v])continue;
dis[v]=dis[u]+1;
mn[v]=min(mn[u],val[v]);
getdis(v,u);
}
}
void calc(int u){
for(int e=head[u],v;e;e=edge[e].nex)
if(!vis[v=edge[e].to]){
dis[v]=1;mn[v]=min(val[v],val[u]);
cnt2=0;getdis(v,u);
for(int i=1;i<=cnt2;++i){
ans=max(ans,(s2[i].fi+1)*s2[i].se);
for(int j=1;j<=cnt1;++j)
ans=max(ans,(s1[j].fi+s2[i].fi+1)*min(s1[j].se,s2[i].se));
}
for(int i=1;i<=cnt2;++i)
s1[++cnt1]=s2[i];
}
cnt1=0;
}
void divide(int u){
vis[u]=1;calc(u);
for(int e=head[u],v;e;e=edge[e].nex){
if(vis[v=edge[e].to])continue;
mxp[root=0]=INT_MAX;
total=siz[v];getroot(v,u);
divide(root);
}
}
signed main(){
//init();
n=read();
for(int i=1;i<=n;++i)
val[i]=read(),ans=max(ans,val[i]);
for(int i=1,u,v;i<n;++i)
u=read(),v=read(),addedge(u,v);
mxp[root=0]=INT_MAX;
total=n;getroot(1,0);
divide(root);
printf("%lld\n",ans);
return 0;
}
解法二
我们发现前者的复杂度主要是在于多个子树之间合并的复杂度不好降。
那么不妨考虑边分治,这样我们合并答案的时候就只有两个子树了。
考虑虚点的权值,因为如果一条路径过了虚点,那么一定过了点 \(u\) ,所以虚点权值赋为 \(u\) 的权值。虚边权值 0,实边权值 1。
三度化然后边分治,两个子树的路径存进两个数组,对着两个数组分别按照路径中的最小值排序,然后倒序枚举双指针即可。注意需要左右各枚举一次双指针。
#include<bits/stdc++.h>
#define int long long
#define pii pair<int,int>
#define mp make_pair
#define fi first
#define se second
using namespace std;
const int N = 2e5+5;
const int inf = 1e18;
void init(){
freopen("a.in","r",stdin);
freopen("a.out","w",stdout);
}
int read(){
int s=0,w=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
while(ch>='0'&&ch<='9')s=s*10+ch-'0',ch=getchar();
return s*w;
}
struct Edge{
int nex,to,dis;
}edge[N<<1],bian[N<<1];
int head[N],pre[N],val[N],siz[N];
int n,elen=1,blen,ans,total,lsc,rsc;
int rt1,rt2,bridge,mx;
pii ls[N],rs[N];
bool vis[N<<1];
void addbian(int u,int v,int w=0){
bian[++blen]={pre[u],v,w};pre[u]=blen;
bian[++blen]={pre[v],u,w};pre[v]=blen;
}
void addedge(int u,int v,int w){
edge[++elen]={head[u],v,w};head[u]=elen;
edge[++elen]={head[v],u,w};head[v]=elen;
}
void rebuild(int u,int fath){
int ff=0;
for(int e=pre[u],v;e;e=bian[e].nex){
if((v=bian[e].to)==fath)continue;
if(!ff)addedge(u,v,1),ff=u;
else{
int tmp=++total;val[tmp]=val[u];
addedge(ff,tmp,0);addedge(tmp,v,1);
ff=tmp;
}
rebuild(v,u);
}
}
void getroot(int u,int fath){
siz[u]=1;
for(int e=head[u],v;e;e=edge[e].nex){
if((v=edge[e].to)==fath||vis[e])continue;
getroot(v,u);siz[u]+=siz[v];
int now=max(siz[v],total-siz[v]);
if(now<mx)mx=now,rt1=u,rt2=v,bridge=e;
}
}
void getdis(int u,int fath,int mn,int dep,int tp){
if(tp==1)ls[++lsc]=mp(mn,dep);
else rs[++rsc]=mp(mn,dep);
for(int e=head[u],v;e;e=edge[e].nex){
if((v=edge[e].to)==fath||vis[e])continue;
getdis(v,u,min(mn,val[v]),dep+edge[e].dis,tp);
}
}
void divide(int u){
if(total==1)return;
rt1=rt2=bridge=lsc=rsc=0;mx=inf;
getroot(u,0);
//cout<<" rt1 "<<rt1<<" rt2 "<<rt2<<" siz "<<siz[rt1]<<" "<<siz[rt2]<<endl;
vis[bridge]=vis[bridge^1]=1;
getdis(rt1,0,val[rt1],0,0);
getdis(rt2,0,val[rt2],0,1);
sort(ls+1,ls+1+lsc);sort(rs+1,rs+1+rsc);
int mxdep=0,ptr=rsc;
for(int i=lsc;i>=1;--i){
while(ptr&&rs[ptr].fi>=ls[i].fi)
mxdep=max(mxdep,rs[ptr--].se);
ans=max(ans,(mxdep+edge[bridge].dis+ls[i].se+1)*ls[i].fi);
}
mxdep=0,ptr=lsc;
for(int i=rsc;i>=1;--i){
while(ptr&&ls[ptr].fi>=rs[i].fi)
mxdep=max(mxdep,ls[ptr--].se);
ans=max(ans,(mxdep+edge[bridge].dis+rs[i].se+1)*rs[i].fi);
}
int tmprt=rt2,tmpsiz=siz[rt2];
total=total-siz[rt2];divide(rt1);
total=tmpsiz;divide(tmprt);
}
signed main(){
//init();
total=n=read();
for(int i=1;i<=n;++i)
val[i]=read(),ans=max(ans,val[i]);
for(int i=1,u,v;i<n;++i)
u=read(),v=read(),addbian(u,v);
rebuild(1,0);divide(1);
printf("%lld\n",ans);
return 0;
}
解法三
那点分治就一定不可做么?我们其实可以点分治+最大值树状数组。
我们把路径最小值作为树状数组下标,链长度作为值。
那么,我们对于一个分治中心,先从左往右一个个地扫描子树,
对于每次扫描,我们先利用之前子树的点(已经插入BIT)和当前子树的点配对算贡献(因为BIT的下标是路径最小值,我们不妨每次只取 最小值 大于 当前子树的点到分治中心的最小权 的之前子树的点),然后将子树中的点插入 BIT。
显然,我们还需要从右往左扫一遍。
(其实对于上述的理解可以看代码)
#include<bits/stdc++.h>
using namespace std;
const int N = 50005;
const int M = 70005;
struct node
{
int dep,group,zz;
bool operator<(const apple &other)const
{
return group<other.group;
}
}e[N];
int tot,val[N],mn[N],dfn[N],group[N],vis[N],sz[N],dep[N];
vector<int>g[N];
int C[M];
void add(int x,int s)
{
while(x<=70000)
{
C[x]=max(C[x],s);
x+=x&-x;
}
}
int query(int x)
{
int ans=0;
while(x)
{
ans=max(ans,C[x]);
x-=x&-x;
}
return ans;
}
void clear(int x)
{
while(x<=70000)
{
C[x]=0;
x+=x&-x;
}
}
void dfs(int x,int fath)
{
dfn[++tot]=x;
sz[x]=1;
for(int i=0;i<g[x].size();i++)
{
int v=g[x][i];
if(v==fath||vis[v])continue;
dfs(v,x);
sz[x]+=sz[v];
}
}
void getdep(int x,int fath,int l)
{
group[x]=l;
for(int i=0;i<g[x].size();i++)
{
int v=g[x][i];
if(v==fath||vis[v])continue;
dep[v]=dep[x]+1;
getdep(v,x,l);
}
}
void getmn(int x,int fath)
{
for(int i=0;i<g[x].size();i++)
{
int v=g[x][i];
if(v==fath||vis[v])continue;
mn[v]=min(mn[x],val[v]);
getmn(v,x);
}
}
long long merg(int x)
{
tot=0;
dfs(x,0);
if(sz[x]==1)
{
vis[x]=1;
return 0;
}
int tmp=INT_MAX,w;
for(int i=1;i<=tot;i++)
{
int mx=sz[x]-sz[dfn[i]];
for(int j=0;j<g[dfn[i]].size();j++)
{
int v=g[dfn[i]][j];
if(vis[v]||sz[v]>sz[dfn[i]])continue;
mx=max(mx,sz[v]);
}
if(tmp>mx)tmp=mx,w=dfn[i];
}
vis[w]=1,group[w]=-1,dep[w]=0,mn[w]=val[w];
getmn(w,0);
for(int i=0;i<g[w].size();i++)
{
int v=g[w][i];
if(vis[v])continue;
dep[v]=1;
getdep(v,w,i);
}
for(int i=1;i<=tot;i++)e[i].group=group[dfn[i]],e[i].dep=dep[dfn[i]],e[i].zz=mn[dfn[i]];
sort(e+1,e+tot+1);
long long ans=0;
for(int i=1;i<=tot;)
{
int wz=i;
while(wz<tot&&e[wz+1].group==e[wz].group)wz++;
for(int j=i;j<=wz;j++)ans=max(ans,1ll*e[j].zz*(query(66001-e[j].zz)+e[j].dep+1));
for(int j=i;j<=wz;j++)add(66001-e[j].zz,e[j].dep);
i=wz+1;
}
for(int i=1;i<=tot;i++)clear(66001-e[i].zz);
for(int i=tot;i>=1;)
{
int wz=i;
while(wz>1&&e[wz-1].group==e[wz].group)wz--;
for(int j=wz;j<=i;j++)ans=max(ans,1ll*e[j].zz*(query(66001-e[j].zz)+e[j].dep+1));
for(int j=wz;j<=i;j++)add(66001-e[j].zz,e[j].dep);
i=wz-1;
}
for(int i=1;i<=tot;i++)clear(66001-e[i].zz);
for(int i=0;i<g[w].size();i++)
{
int v=g[w][i];
if(vis[v])continue;
ans=max(ans,merg(v));
}
return ans;
}
int main()
{
freopen("a.in","r",stdin);
freopen("a.out","w",stdout);
int n;
cin>>n;
for(int i=1;i<=n;i++)scanf("%d",&val[i]);
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
g[u].push_back(v);
g[v].push_back(u);
}
long long ans=merg(1);
for(int i=1;i<=n;i++)ans=max(ans,1ll*val[i]);
cout<<ans<<endl;
return 0;
}
解法四
我们其实可以用并查集维护直径。
我们从大到小加边,记录连通块的直径的两个端点,合并连通块的时候更新直径,更新答案显然是用 最新加的边*直径。
有一个引理是,合并两个联通块时,新直径的端点应该在原来 4 个端点之中。
(画图用反证法易证)
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N = 5e4+5;
void init(){
freopen("a.in","r",stdin);
freopen("a.out","w",stdout);
}
int read(){
int s=0,w=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
while(ch>='0'&&ch<='9')s=s*10+ch-'0',ch=getchar();
return s*w;
}
struct Edge{int nex,to;}edge[N<<1];
struct Gruop{int l,r;}gp[N];
struct Node{int v,id;}nod[N];
int head[N],f[N][20],fath[N],dep[N],vis[N];
int n,elen,t,ans;
void addedge(int u,int v){
edge[++elen]={head[u],v};head[u]=elen;
edge[++elen]={head[v],u};head[v]=elen;
}
bool cmp(Node a,Node b){return a.v>b.v;}
void dfs(int u){
for(int i=1;i<=t;++i)
f[u][i]=f[f[u][i-1]][i-1];
for(int e=head[u],v;e;e=edge[e].nex){
if((v=edge[e].to)==f[u][0])continue;
dep[v]=dep[f[v][0]=u]+1;
dfs(v);
}
}
int find(int x){return fath[x]==x?x:fath[x]=find(fath[x]);}
int lca(int x,int y){
if(dep[x]<dep[y])swap(x,y);
for(int i=t;i>=0;--i)
if(dep[f[x][i]]>=dep[y])x=f[x][i];
if(x==y)return x;
for(int i=t;i>=0;--i)
if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i];
return f[x][0];
}
int dis(int x,int y){return dep[x]+dep[y]-2*dep[lca(x,y)]+1;}
int merge(int x,int y){
int fx=find(x),fy=find(y);fath[fx]=fy;
int res=0,temp,l=0,r=0;
temp=dis(gp[fx].l,gp[fx].r);if(temp>res)res=temp,l=gp[fx].l,r=gp[fx].r;
temp=dis(gp[fx].l,gp[fy].l);if(temp>res)res=temp,l=gp[fx].l,r=gp[fy].l;
temp=dis(gp[fx].l,gp[fy].r);if(temp>res)res=temp,l=gp[fx].l,r=gp[fy].r;
temp=dis(gp[fx].r,gp[fy].l);if(temp>res)res=temp,l=gp[fx].r,r=gp[fy].l;
temp=dis(gp[fx].r,gp[fy].r);if(temp>res)res=temp,l=gp[fx].r,r=gp[fy].r;
temp=dis(gp[fy].l,gp[fy].r);if(temp>res)res=temp,l=gp[fy].l,r=gp[fy].r;
return gp[fy]={l,r},res;
}
signed main(){
//init();
n=read();t=log(n)/log(2)+1;
for(int i=1;i<=n;++i)
nod[i]={read(),i},gp[i]={i,i},fath[i]=i;
for(int i=1,u,v;i<n;++i)
u=read(),v=read(),addedge(u,v);
dfs(1);
sort(nod+1,nod+1+n,cmp);
for(int i=1;i<=n;++i){
int now=nod[i].id,l=0;vis[now]=1;
for(int e=head[now],v;e;e=edge[e].nex)
if(vis[v=edge[e].to])l=max(l,merge(now,v));
ans=max(ans,l*nod[i].v);
}
printf("%lld\n",ans);
return 0;
}