树的合并 connect
树的合并 connect
题目描述
话说moreD经过不懈努力,终于背完了循环整数,也终于完成了他的蛋糕大餐。
但是不幸的是,moreD得到了诅咒,受到诅咒的原因至今无人知晓。
moreD在发觉自己得到诅咒之后,决定去寻找闻名遐迩的术士CD帮忙。
话说CD最近在搞OI,遇到了一道有趣的题目:
给定两棵树,则总共有N*M种方案把这两棵树通过加一条边连成一棵树,那这N*M棵树的直径大小之和是多少呢?
CD为了考验moreD是否值得自己费心力为他除去诅咒,于是要他编程回答这个问题,但是这moreD早就被诅咒搞晕了头脑,就只好请你帮助他了。
输入
第一行两个正整数N,M,分别表示两棵树的大小。
接下来N-1行,每行两个正整数ai,bi,表示第一棵树上的边。
接下来M-1行,每行两个正整数ci,di,表示第二棵树上的边。
输出
一行一个整数,表示答案。
样例输入
4 3 1 2 2 3 2 4 1 3 2 3
样例输出
53
提示
【数据范围】
对于20%的数据满足N<=300,M<=300
对于50%的数据满足N,M<=3000
对于100%的数据满足N<=10^5,M<=10^5,1<=ai,bi<=N,1<=ci,di<=M
【提示】
树的直径指的是树上的最长简单路径。
solution
预处理a[i]表示树a上以i开头的最长链
同理b[i]表示树b上以i开头的最长链
f[i]为i向下的最长链,dp[i]为向上的最长链
那么
f[i]用树形DP可以求出
考虑求dp[i]
g[i]为i向下的不与最长链的共有子节点的最长的链
举例:son[i]为i向下最长链的子节点 g[i] 就是除了son[k]外所有子节点向下的最长链+1
表述不清。。。
可以自己想想怎么求a,b
令Max表示A树和B树的直径的Max
排序统计即可
#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<cmath>
#define maxn 100005
using namespace std;
int n,m,tot,head[maxn],f[maxn],g[maxn],dp[maxn],t1,t2;
long long a[maxn],b[maxn],sum[maxn],ans,ma;
struct node{
int v,nex;
}e[maxn*2];
void lj(int t1,int t2){
tot++;e[tot].v=t2;e[tot].nex=head[t1];head[t1]=tot;
}
void dfs1(int k,int fa){
int Max=-1e9,max2=-1e9;
for(int i=head[k];i;i=e[i].nex){
if(e[i].v!=fa){
dfs1(e[i].v,k);
if(f[e[i].v]>Max){
max2=max(max2,Max);
Max=f[e[i].v];
}
else max2=max(max2,f[e[i].v]);
}
}
if(Max==-1e9)f[k]=0,g[k]=-1e9;
else {f[k]=Max+1;g[k]=max2+1;}
//cout<<k<<' '<<f[k]<<' '<<g[k]<<endl;
}
void dfs2(int k,int fa){
for(int i=head[k];i;i=e[i].nex){
if(e[i].v!=fa){
dp[e[i].v]=max(dp[e[i].v],dp[k]+1);
if(f[e[i].v]==f[k]-1){
dp[e[i].v]=max(dp[e[i].v],g[k]+1);
}
else dp[e[i].v]=max(dp[e[i].v],f[k]+1);
dfs2(e[i].v,k);
}
}
}
void Q(){
for(int i=1;i<=n;i++)head[i]=f[i]=g[i]=dp[i]=0;
tot=0;
}
int main()
{
cin>>n>>m;
for(int i=1;i<n;i++){
scanf("%d%d",&t1,&t2);
lj(t1,t2);lj(t2,t1);
}
dfs1(1,0);dfs2(1,0);
for(int i=1;i<=n;i++)a[i]=max(f[i],dp[i]);
Q();
for(int i=1;i<m;i++){
scanf("%d%d",&t1,&t2);
lj(t1,t2);lj(t2,t1);
}
dfs1(1,0);dfs2(1,0);
for(int i=1;i<=m;i++){
b[i]=max(f[i],dp[i]);
}
sort(a+1,a+n+1);sort(b+1,b+m+1);
for(int i=1;i<=n;i++)sum[i]=sum[i-1]+a[i];
ma=max(a[n],b[m]);
int l=1;
for(int i=m;i>=1;i--){
while(b[i]+a[l]+1<ma&&l<=n)l++;
long long num=n-l+1;
long long tmp=b[i]*num;tmp+=sum[n]-sum[l-1];tmp+=num;
tmp+=ma*(n-num);
ans+=tmp;
}
cout<<ans<<endl;
return 0;
}