UOJ #11 - 【UTR #1】ydc的大树(换根 dp)
Emmm……这题似乎做法挺多的,那就提供一个想起来写起来都不太困难的做法吧。
首先不难想到一个时间复杂度 O(n2)O(n2) 的做法:对于每个黑点我们以它为根求出离它距离最远的点集 S,那么一个白点能够摧毁这个黑点当且仅当这个白点在黑点到点集 S 中的点的 LCA 的路径上。这样我们就可以求出所有白点的答案了。
考虑优化这个过程,注意“以每个点为根”一脸可用换根 dp 优化的亚子,因此考虑换根 dp,如果单纯地求到每个点距离最远的黑点那你肯定会求的欸(yyq 附体),直接一遍常规的换根 dp 就完事了。不过此题还需求 LCA,因此考虑以下做法,我们先以 1 为根一遍 DFS,对于每个点 x 求出以 x 为根的子树内里 x 最远的点的 LCA,然后我们再额外记录两个数组 dis_outx 表示去掉以 x 为根的子树内剩余部分离 x 最远的黑点离 x 的距离,以及 lca_outi 表示它们的 LCA,怎样求这两个数组呢?就按照第二遍 DFS 的套路从上往下更新,当 DFS 到 x 时将 x 的每个子树的信息压入一个 multiset
,并枚举 x 的每个儿子 y,将 y 的信息从 multiset
中暂时删除,如果 multiset
中距离最大值和次大值相等那么 lca_outy 就是 x,否则 lca_outy 就是 multiset
中最大值对应的 LCA,最后更新答案时对于每个 x 看它子树内和子树外黑点距离其的最大值,哪边大那个 LCA 就属于那边,如果相等那么 LCA 就是 x。
最后统计答案树上差分即可,时间复杂度 nlogn。
const int MAXN=1e5; const int LOG_N=17; const int INF=0x3f3f3f3f; int n,m,hd[MAXN+5],to[MAXN*2+5],val[MAXN*2+5],nxt[MAXN*2+5],ec=0;bool is[MAXN+5]; void adde(int u,int v,int w){to[++ec]=v;val[ec]=w;nxt[ec]=hd[u];hd[u]=ec;} int fa[MAXN+5][LOG_N+2],dep[MAXN+5],dis[MAXN+5],lca_in[MAXN+5]; int dis_out[MAXN+5],lca_out[MAXN+5]; int mark[MAXN+5]; void dfs1(int x=1,int f=0){ dis[x]=(is[x]^1)*(-INF);lca_in[x]=x;fa[x][0]=f; for(int e=hd[x];e;e=nxt[e]){ int y=to[e],z=val[e];if(y==f) continue; dep[y]=dep[x]+1;dfs1(y,x); if(dis[y]+z>dis[x]) dis[x]=dis[y]+z,lca_in[x]=lca_in[y]; else if(dis[y]+z==dis[x]) lca_in[x]=x; } /*printf("%d %d %d\n",x,dis[x],lca_in[x]);*/ } int getlca(int x,int y){ if(dep[x]<dep[y]) x^=y^=x^=y; for(int i=LOG_N;~i;i--) if(dep[x]-(1<<i)>=dep[y]) x=fa[x][i]; if(!(x^y)) return x; for(int i=LOG_N;~i;i--) if(fa[x][i]^fa[y][i]) x=fa[x][i],y=fa[y][i]; return fa[x][0]; } void dfs2(int x=1,int f=0){ multiset<pii> st; for(int e=hd[x];e;e=nxt[e]){ int y=to[e],z=val[e]; if(y==f) st.insert(mp(dis_out[x],lca_out[x])); else st.insert(mp(dis[y]+z,lca_in[y])); } if(is[x]) st.insert(mp(0,x)); for(int e=hd[x];e;e=nxt[e]){ int y=to[e],z=val[e];if(y==f) continue; st.erase(st.find(mp(dis[y]+z,lca_in[y]))); if(st.empty()) dis_out[y]=(is[x])?z:(-INF),lca_out[y]=(is[x])?x:0; else{ // printf("Node %d\n",y); pii p=*st.rbegin();st.erase(st.find(p));dis_out[y]=z+p.fi; // printf("%d %d\n",p.fi,p.se); if(st.empty()||(*st.rbegin()).fi<p.fi) lca_out[y]=p.se; else lca_out[y]=x;st.insert(p); } st.insert(mp(dis[y]+z,lca_in[y]));dfs2(y,x); } } void dfs3(int x=1,int f=0){ for(int e=hd[x];e;e=nxt[e]){ int y=to[e];if(y==f) continue; dfs3(y,x);mark[x]+=mark[y]; } } int main(){ scanf("%d%d",&n,&m);for(int i=1,x;i<=m;i++) scanf("%d",&x),is[x]=1; for(int i=1,u,v,w;i<n;i++) scanf("%d%d%d",&u,&v,&w),adde(u,v,w),adde(v,u,w);dis_out[1]=-INF; dfs1();for(int i=1;i<=LOG_N;i++) for(int j=1;j<=n;j++) fa[j][i]=fa[fa[j][i-1]][i-1];dfs2(); for(int i=1;i<=n;i++) if(is[i]){ int y=(dis[i]==dis_out[i])?i:((dis[i]<dis_out[i])?lca_out[i]:lca_in[i]); // printf("%d %d %d %d\n",i,dis[i],dis_out[i],y); int z=getlca(i,y);mark[i]++;mark[y]++;mark[z]--;mark[fa[z][0]]--; } dfs3();int mx=0,cnt=0; for(int i=1;i<=n;i++) if(!is[i]) chkmax(mx,mark[i]); for(int i=1;i<=n;i++) cnt+=(!is[i]&&mark[i]==mx); printf("%d %d\n",mx,cnt); return 0; }
【推荐】还在用 ECharts 开发大屏?试试这款永久免费的开源 BI 工具!
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步