BZOJ 1912(树的直径+LCA)
题面
分析
显然,如果不加边,每条边都要走2次,总答案为2(n-1)
考虑k=1的朴素情况:
加一条边(a,b),这条边和树上a->b的路径形成一个环,这个环上的边只需要走一遍,所以答案会减少dist(a,b)-1 (a->b的路径少走一边,但是又多加了一条边,最终答案为2*(n-1)-dist(a,b)+1
显然dist(a,b)应该最大,求树上的直径\(L_1\)即可
接着推广到k=2的情况
加第二条边时,又会形成一个环,这个环和第一条边形成的环可能有重叠,重叠部分要走两遍,答案又会增加
因此,我们把直径上的边权取反(1变为-1),在直径取反后的树上再次求直径,设直径为\(L_2\)
答案就是\(2(n-1)-(L_1-1)-(L_2-1)\)
注意两次BFS求直径的方法在负权图上求直径不成立,需要用树形DP
具体方法如下:
1.在原图上用两次BFS求出直径\(L_1\)端点a,b
2.通过LCA和树上差分将直径上的边标记,并取反边权
3.用树形DP求出新直径\(L_2\)
代码
#include<iostream>
#include<cstdio>
#include<queue>
#include<cstring>
#include<cmath>
#define maxlog 32
#define maxn 400005
#define INF 0x7fffffff
using namespace std;
int n,k;
struct edge {
int from;
int to;
int next;
int len;
} E[maxn*2];
int sz=1;
int head[maxn];
void add_edge(int u,int v,int w) {
sz++;
E[sz].from=u;
E[sz].to=v;
E[sz].len=w;
E[sz].next=head[u];
head[u]=sz;
}
namespace k_1 {
struct node {
int x;
int t;
node() {
}
node(int u,int dis) {
x=u;
t=dis;
}
};
int vis[maxn];
node bfs(int s) {
node now,ans;
queue<node>q;
q.push(node(s,0));
memset(vis,0,sizeof(vis));
ans.t=-INF;
while(!q.empty()) {
now=q.front();
q.pop();
int x=now.x;
vis[x]=1;
if(now.t>ans.t) {
ans.t=now.t;
ans.x=now.x;
}
for(int i=head[x]; i; i=E[i].next) {
int y=E[i].to;
if(!vis[y]) {
q.push(node(y,now.t+1));
}
}
}
return ans;
}
void solve() {
node p=bfs(1);
node q=bfs(p.x);
printf("%d\n",2*(n-1)-q.t+1);
}
}
namespace k_2 {
int log2n;
struct node {
int x;
int t;
node() {
}
node(int u,int dis) {
x=u;
t=dis;
}
};
int vis[maxn];
int mark[maxn];
node bfs(int s,int tim) {
node now,ans;
queue<node>q;
q.push(node(s,0));
memset(vis,0,sizeof(vis));
ans.t=-INF;
ans.x=s;
if(n==1) return node(1,0);
while(!q.empty()) {
now=q.front();
q.pop();
int x=now.x;
vis[x]=1;
if(tim==1) {
if(now.t>=ans.t&&now.x!=s) {
ans.t=now.t;
ans.x=now.x;
}
}else{
if(now.t>=ans.t) {
ans.t=now.t;
ans.x=now.x;
}
}
for(int i=head[x]; i; i=E[i].next) {
int y=E[i].to;
if(!vis[y]) {
q.push(node(y,now.t+E[i].len));
}
}
}
return ans;
}
int anc[maxn][maxlog];
int deep[maxn];
void dfs1(int x,int fa) {
deep[x]=deep[fa]+1;
anc[x][0]=fa;
for(int i=1; i<=log2n; i++) {
anc[x][i]=anc[anc[x][i-1]][i-1];
}
for(int i=head[x]; i; i=E[i].next) {
int y=E[i].to;
if(y!=fa) {
dfs1(y,x);
}
}
}
void dfs2(int x,int fa) {
for(int i=head[x]; i; i=E[i].next) {
int y=E[i].to;
if(y!=fa) {
dfs2(y,x);
mark[x]+=mark[y];
}
}
}
int lca(int x,int y) {
if(deep[x]<deep[y]) swap(x,y);
for(int i=log2n; i>=0; i--) {
if(deep[anc[x][i]]>=deep[y]) x=anc[x][i];
}
if(x==y) return x;
for(int i=log2n; i>=0; i--) {
if(anc[x][i]!=anc[y][i]) {
x=anc[x][i];
y=anc[y][i];
}
}
return anc[x][0];
}
int len=0;
int d[maxn];
void dp(int x,int fa){
d[x]=0;
for(int i=head[x];i;i=E[i].next){
int y=E[i].to;
if(y!=fa){
dp(y,x);
len=max(len,d[x]+d[y]+E[i].len);
d[x]=max(d[x],d[y]+E[i].len);
}
}
}
void solve() {
int u,v;
int ans=2*(n-1);
log2n=log2(n)+1;
dfs1(1,0);
node p=bfs(1,1);
node q=bfs(p.x,2);
ans=ans-q.t+1;
memset(mark,0,sizeof(mark));
mark[p.x]++;
mark[q.x]++;
mark[lca(p.x,q.x)]-=2;
dfs2(1,0);
memset(head,0,sizeof(head));
memset(E,0,sizeof(E));
sz=1;
for(int i=2; i<=n; i++) {
add_edge(i,anc[i][0],mark[i]==1?-1:1);
add_edge(anc[i][0],i,mark[i]==1?-1:1);
}
memset(d,-0x3f,sizeof(d));
dp(1,0);
ans=ans-len+1;
printf("%d\n",ans);
}
}
int main() {
int u,v;
scanf("%d %d",&n,&k);
sz=1;
for(int i=1; i<n; i++) {
scanf("%d %d",&u,&v);
add_edge(u,v,1);
add_edge(v,u,1);
}
if(k==1) k_1::solve();
else k_2::solve();
}
版权声明:因为我是蒟蒻,所以请大佬和神犇们不要转载(有坑)的文章,并指出问题,谢谢