基础树上问题之 树的直径 + 最近公共祖先 例题及学习笔记(入门版)
本篇博客是关于洛谷题单【图论2-1】基础树上问题 的题目题解合集
紫题还不会,先鸽
同时附加一点我的个人学习心得
基础树上问题 除了 树形dp 外,还有 树的直径 和 LCA 等问题
树的直径
树的直径即树上最长路的长度
求法是首先任取一点作为根,求出一个到根最远的点,此为直径的一端;再以这个端点为根再进行一次dfs,求到根最远的点,为直径的另一端点
先放个树的直径的板子:
树的直径
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 1e5 + 10;
int n, k;
vector<int>e[N];
int d[N]; //点的实际深度
int maxd[N];//点可以到达的最大深度
int s, t, mxd;
int f[N], ans[N]; //到其他点的最大距离
void dfs1(int now, int fa) {
d[now] = d[fa] + 1;
if(d[now] > mxd){
mxd = d[now];
s = now;
}
for(auto i:e[now]) {
if(i == fa) continue;
dfs1(i, now);
}
}
void dfs2(int now, int fa) {
d[now] = d[fa] + 1;
f[now] = fa;
if(d[now] > mxd){
mxd = d[now];
t = now;
}
for(auto i:e[now]) {
if(i == fa) continue;
dfs2(i, now);
}
}
void solve() {
//两次dfs求直径
mxd = -1;
dfs1(1, 0);
d[0] = -1; mxd = -1;
dfs2(s, 0);
//s 和 t 即为树的直径
}
int main(){
cin >> n >> k;
for(int i = 1, u, v; i < n; i++) {
scanf("%d%d", &u, &v);
e[u].pb(v);
e[v].pb(u);
}
solve();
system("pause");
return 0;
}
----------------接下来是例题-----------------------
题意
求到n个人距离之和最小的树上的点
思路
其实就是先任选一个点,求出距离,可以 \(O(n)\) 更新其他的点
点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 5e5+10, mod = 998244353;
int t, n, q, u, v;
struct Edge{
int to,nex;
}e[2*N];
int head[N],d[N],fa[N][30], sz[N];
ll f[N], mx;
int ind, cnt;
void add(int u,int v){
e[++cnt].to=v;
e[cnt].nex=head[u];
head[u]=cnt;
}
void dfs(int now,int father){
sz[now] = 1;
d[now] = d[father] + 1;
for(int i=head[now];i;i=e[i].nex){
if(e[i].to!=father){
dfs(e[i].to,now);
sz[now] += sz[e[i].to];
}
}
}
void dfs2(int now,int father){
f[now] = f[father] - sz[now] + (n - sz[now]);
for(int i=head[now];i;i=e[i].nex){
int x = e[i].to;
if(x != father){
dfs2(x, now);
}
}
}
int main(){
scanf("%d", &n);
for (int i = 1; i < n; i++) {
scanf("%d%d", &u, &v);
add(u, v);
add(v, u);
}
dfs(1, 0);
f[1] = 0;
for (int j = 1; j <= n; j++ ) {
f[1] += d[j] - d[1];
}
mx = f[1];
ind = 1;
for (int i = head[1]; i; i = e[i].nex) {
dfs2(e[i].to, 1);
}
for (int i = 1; i <= n; i++) {
if(mx > f[i]){
mx = f[i];
ind = i;
}
}
cout<<ind<<' '<<mx<<endl;
system("pause");
return 0;
}
题意
选k个不经过其他城市就两两可达的点作为核心城市,求非核心城市到核心城市的最大距离的最小值
思路
如果 $k = 1 $ ,那这个城市就是树的直径的中点
如果 $k > 1 $ ,先找到第一个核心城市,然后从这个城市开始dfs,贪心地选取剩下的城市。具体见代码
点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 1e5 + 10;
int n, k;
vector<int>e[N];
int d[N]; //点的实际深度
int maxd[N];//点可以到达的最大深度
int s, t, mxd;
int f[N], ans[N]; //到其他点的最大距离
void dfs1(int now, int fa) {
d[now] = d[fa] + 1;
if(d[now] > mxd){
mxd = d[now];
s = now;
}
for(auto i:e[now]) {
if(i == fa) continue;
dfs1(i, now);
}
}
void dfs2(int now, int fa) {
d[now] = d[fa] + 1;
f[now] = fa;
if(d[now] > mxd){
mxd = d[now];
t = now;
}
for(auto i:e[now]) {
if(i == fa) continue;
dfs2(i, now);
}
}
void dfs_k(int now, int fa) {
d[now] = d[fa] + 1;
maxd[now] = d[now];
for(auto i:e[now]){
if(i == fa) continue;
dfs_k(i, now);
maxd[now] = max(maxd[now], maxd[i]);
}
}
void solve() {
//两次dfs求直径
mxd = -1;
dfs1(1, 0);
d[0] = -1; mxd = -1;
dfs2(s, 0);
//找直径中点t
int tt = t;
for(int i = 1; i <= (d[tt] - d[s]) / 2 ; i++) t = f[t];
//确定k个点 , 首先求出每个点能到达(往下走)的最大深度
d[0] = -1;
dfs_k(t, 0);
for(int i = 1; i <= n; i++) {
// cout<<i<<' '<<d[i]<<' '<<maxd[i]<<endl; ///
ans[i] = maxd[i] - d[i];
}
sort(ans + 1, ans + n + 1, greater<int>());
printf("%d\n", ans[k + 1] + 1);
}
int main(){
cin >> n >> k;
for(int i = 1, u, v; i < n; i++) {
scanf("%d%d", &u, &v);
e[u].pb(v);
e[v].pb(u);
}
solve();
system("pause");
return 0;
}
题意
给定一棵树和一个距离s,你需要找到一段树的直径上的长度不超过s的线段作为树网的核,使得其他点到这个树网的核的距离的最大值最小
思路
首先可以看到 \(n <= 100\) ,于是可以采用 \(O(n^2)\) 做法,枚举直径上每一段长度 $ <= s$ 的线段,然后求 \(ans\)
\(ans\) 的求法可以分为,直径上的点到线段的距离,和直径外的点到线段的距离
直径上的点的最大距离肯定是到两个端点的较大值,直径外的点只需要求出到直径上每个点的最小值即可
点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 1e5 + 10;
int n, S;
vector<pii>e[N];
int d[N]; //点的实际深度
int s, t, mxd;
int f[N];
bool vis[N]; //直径上的点
void dfs1(int now, int fa) {
if(d[now] > mxd){
mxd = d[now];
s = now;
}
for(auto i:e[now]) {
if(i.first == fa) continue;
d[i.first] = d[now] + i.second;
dfs1(i.first, now);
}
}
void dfs2(int now, int fa) {
if(d[now] > mxd){
mxd = d[now];
t = now;
}
for(auto i:e[now]) {
if(i.first == fa) continue;
d[i.first] = d[now] + i.second;
f[i.first] = now;
dfs2(i.first, now);
}
}
void solve() {
//两次dfs求直径
mxd = -1; d[1] = 0;
dfs1(1, 1);
mxd = -1; d[s] = 0;
dfs2(s, s);
f[s] = 0;
int ans = 1e9;
//答案第一种来源:直径上的
for(int i = t; i; i = f[i]){
vis[i] = 1;
for(int j = i; j; j = f[j]){
if(d[i] - d[j] <= S){
ans = min(ans, max(d[j], d[t] - d[i]));
}
}
}
// printf("%d\n", ans);
//答案另外一种来源:直径之外的
for(int j = 1; j <= n; j++){
if(vis[j]) continue;
int mx = 1e9;
for(int i = t; i; i = f[i]) {
if(d[j] > d[i]) mx = min(mx, d[j] - d[i]);
}
ans = max(ans, mx);
}
printf("%d\n", ans);
}
int main(){
cin >> n >> S;
for(int i = 1, u, v, w; i < n; i++) {
scanf("%d%d%d", &u, &v, &w);
e[u].pb({v,w});
e[v].pb({u,w});
}
solve();
system("pause");
return 0;
}
最近公共祖先(LCA)
先放个LCA的板子,亲测能通过洛谷上LCA相关的题目
LCA
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 1e5 + 10;
int t, n, q;
vector<int> e[N];
int a[2], b[2];
int f[N][33], d[N];
void dfs(int now, int fa) {
d[now] = d[fa] + 1;
f[now][0] = fa;
for(int i = 1; (1 << i) <= d[now]; i++) {
f[now][i] = f[f[now][i - 1]][i - 1];
}
for(auto i:e[now]) {
if(i == fa) continue;
dfs(i, now);
}
}
int lca(int a, int b) {
if(d[a] < d[b]) swap(a, b);
int dep;
for(dep = 0; (1 << dep) <= d[a]; dep++) ; dep--;
for(int i = dep; i >= 0 ; i--) {
if(d[a] - (1 << i) >= d[b]) a = f[a][i];
}
if(a == b) return a;
for(int i = dep; i >= 0; i--) {
if(f[a][i] == f[b][i]) continue;
else {
a = f[a][i];
b = f[b][i];
}
}
return f[a][0];
}
inline int dis(int a, int b) {
return d[a] + d[b] - 2 * d[lca(a, b)];
}
inline bool check(int a, int b, int ff) {
if(dis(a, ff) + dis(b, ff) == dis(a, b)) return 1;
return 0;
}
int main(){
cin >> n >> q;
for(int i = 1, u, v; i < n; i++) {
scanf("%d%d", &u, &v);
e[u].pb(v);
e[v].pb(u);
}
dfs(1, 0);
while(q--) {
scanf("%d%d%d%d", &a[0], &b[0], &a[1], &b[1]);
int f1 = lca(a[0], b[0]); int low1 = max(d[a[0]], d[b[0]]);
int f2 = lca(a[1], b[1]); int low2 = max(d[a[1]], d[b[1]]);
int f = lca(f1, f2);
// cout<< f1 <<' '<<f2<<endl; ///
if(check(a[0], b[0], f2) || check(a[1], b[1], f1) ) puts("Y");
else puts("N");
}
system("pause");
return 0;
}
----------------接下来是例题-----------------------
P5836 [USACO19DEC]Milk Visits S
题意
一棵树上,每个点有一种品种的奶牛,总共有两种奶牛。
有 \(q\) 位客人要从 \(u\) 点到 \(v\) 点参观,问能否经过特定种类的奶牛。
思路
随便指定一个点为根,可以用 \(dfs\) 求出每个点到根这条路径上两种牛的数目,询问一条到祖先的路径上牛的数目只要用这个点的减去祖先的即可
对于每个询问求出 \(u\) 到 \(lca(u,v)\)、 \(v\) 到 \(lca(u,v)\) 上的牛的数目,大于零即puts("Y")
点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 1e5+10, mod = 998244353;
int t, n, q, u, v;
char s[N];
vector<int> g[N];
struct Edge{
int to,nex;
}e[2*N];
int head[N],d[N],fa[N][30],num[N][2];
int cnt;
char ch;
void add(int u,int v){
e[++cnt].to=v;
e[cnt].nex=head[u];
head[u]=cnt;
}
void dfs(int now,int father){
fa[now][0]=father;
d[now]=d[father]+1;
num[now][0] = num[father][0] + (s[now] == 'H');
num[now][1] = num[father][1] + (s[now] == 'G');
for(int i=1;(1<<i)<=d[now];i++){
fa[now][i]=fa[fa[now][i-1]][i-1];
}
for(int i=head[now];i;i=e[i].nex){
if(e[i].to!=father) dfs(e[i].to,now);
}
}
int lca(int a,int b) { //非常标准的lca查找{
if(d[a]<d[b]) swap(a,b); //d[a]大
int dep;
for(dep=0;(1<<dep)<=d[a];dep++);
dep--;
for(int i=dep;i>=0;i--)
if(d[a]-(1<<i)>=d[b])
a=fa[a][i]; //先把b移到和a同一个深度
if(a==b) return a; //特判,如果b上来和就和a一样了,那就可以直接返回答案了
for(int i=dep;i>=0;i--){
if(fa[a][i]==fa[b][i])
continue;
else
a=fa[a][i],b=fa[b][i]; //A和B一起上移
}
return fa[a][0];
}
int main(){
scanf("%d%d", &n, &q);
scanf("%s", s+1);
for (int i = 1; i < n; i++) {
scanf("%d%d", &u, &v);
add(u, v);
add(v, u);
}
dfs(1, 1);
while (q--) {
scanf("%d%d", &u, &v); cin>>ch;
int f = lca(u, v);
int ans;
if(ch == 'H') ans = num[u][0] + num[v][0] - num[f][0] - num[fa[f][0]][0];
else ans = num[u][1] + num[v][1] - num[f][1] - num[fa[f][0]][1];
if(ans) printf("1");
else printf("0");
}
system("pause");
return 0;
}
题意
从\(a\) 到 \(b\) 和 从 \(c\) 到 \(d\) 两段路径上,判断是否存在某点使得两段路径相交
思路
假设存在某一点在两条路径上,只需要判断是否满足 \(lca(a,b)\) 在 从 \(c\) 到 \(d\) 的路径上,或者 \(lca(c,d)\) 在 从 \(a\) 到 \(b\) 的路径上
具体方法:如果 \(dis[lca(a,b)][c] + dis[lca(a,b)][d] == dis[c][d]\) 即可认为 \(lca(a,b)\) 在 从 \(c\) 到 \(d\) 的路径上
点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 1e5 + 10;
int t, n, q;
vector<int> e[N];
int a[2], b[2];
int f[N][33], d[N];
//yes 的情况:一条路径的lca在另外一条路径上
//怎么知道一条路径上包含另外一个点?
void dfs(int now, int fa) {
d[now] = d[fa] + 1;
f[now][0] = fa;
for(int i = 1; (1 << i) <= d[now]; i++) {
f[now][i] = f[f[now][i - 1]][i - 1];
}
for(auto i:e[now]) {
if(i == fa) continue;
dfs(i, now);
}
}
int lca(int a, int b) {
if(d[a] < d[b]) swap(a, b);
int dep;
for(dep = 0; (1 << dep) <= d[a]; dep++) ; dep--;
for(int i = dep; i >= 0 ; i--) {
if(d[a] - (1 << i) >= d[b]) a = f[a][i];
}
if(a == b) return a;
for(int i = dep; i >= 0; i--) {
if(f[a][i] == f[b][i]) continue;
else {
a = f[a][i];
b = f[b][i];
}
}
return f[a][0];
}
inline int dis(int a, int b) {
return d[a] + d[b] - 2 * d[lca(a, b)];
}
inline bool check(int a, int b, int ff) {
if(dis(a, ff) + dis(b, ff) == dis(a, b)) return 1;
return 0;
}
int main(){
cin >> n >> q;
for(int i = 1, u, v; i < n; i++) {
scanf("%d%d", &u, &v);
e[u].pb(v);
e[v].pb(u);
}
dfs(1, 0);
while(q--) {
scanf("%d%d%d%d", &a[0], &b[0], &a[1], &b[1]);
int f1 = lca(a[0], b[0]); int low1 = max(d[a[0]], d[b[0]]);
int f2 = lca(a[1], b[1]); int low2 = max(d[a[1]], d[b[1]]);
int f = lca(f1, f2);
// cout<< f1 <<' '<<f2<<endl; ///
if(check(a[0], b[0], f2) || check(a[1], b[1], f1) ) puts("Y");
else puts("N");
}
system("pause");
return 0;
}
题意
一棵树上,每次询问给定3个点,问哪个点x到这三个点的距离之和是最小的。并求出这个最小距离
思路
如果3个点在同一个位置,答案就是这个点
如果3个点有两个在同一位置,答案是另一个点
否则答案是与其他两个lca不同的那个lca
最小距离为 \(dis[x][a] + dis[x][b] + dis[x][c]\)
点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 5e5 + 10;
int t, n, q, x;
vector<int> e[N];
int f[N][33], d[N];
int a, b, c;
void dfs(int now, int fa) {
d[now] = d[fa] + 1;
f[now][0] = fa;
for(int i = 1; (1 << i) <= d[now]; i++) {
f[now][i] = f[f[now][i - 1]][i - 1];
}
for(auto i:e[now]) {
if(i == fa) continue;
dfs(i, now);
}
}
int lca(int a, int b) {
if(d[a] < d[b]) swap(a, b);
int dep;
for(dep = 0; (1 << dep) <= d[a]; dep++) ; dep--;
for(int i = dep; i >= 0 ; i--) {
if(d[a] - (1 << i) >= d[b]) a = f[a][i];
}
if(a == b) return a;
for(int i = dep; i >= 0; i--) {
if(f[a][i] == f[b][i]) continue;
else {
a = f[a][i];
b = f[b][i];
}
}
return f[a][0];
}
inline int dis(int a, int b) {
return d[a] + d[b] - 2 * d[lca(a, b)];
}
int main(){
cin >> n >> q;
for(int i = 1, u, v; i < n; i++) {
scanf("%d%d", &u, &v);
e[u].pb(v);
e[v].pb(u);
}
dfs(1, 0);
while(q--) {
scanf("%d%d%d", &a, &b, &c);
if(a == b && b == c) {
printf("%d %d\n", a, 0);
continue;
}
if(a == b || b == c) {
if(a == b) x = a;
else if(c == b) x = b;
else if(c == a) x = a;
}
else {
int f1 = lca(a, b);
int f2 = lca(a, c);
int f3 = lca(b, c);
if(f1 == f2) x = f3;
else if(f1 == f3) x = f2;
else if(f2 == f3) x = f1;
}
printf("%d %d\n", x, dis(a,x) + dis(b,x) + dis(c,x));
}
system("pause");
return 0;
}
题意
一棵树,每个点有一个颜色,求包含每种颜色的线段的种数
思路
分类讨论:
如果没有这个颜色的点,答案是 \(n * (n - 1) / 2\)
有1个点,dfs求出
有多个点,如果经过那些点的树有两个叶结点,那么就有答案
否则没有答案,因为不可能有线段可以经过分3叉的树
点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 1e6 + 10;
int t, n, q, x;
vector<int> e[N];
int a[N], tot[N], sz[N], cnt[N]; //子树里颜色为i的点的个数
int cntend[N]; //颜色i的端点
ll ans[N], ans2[N];
void dfs(int now, int fa) {
int color = a[now]; int k = cnt[color];
sz[now] = 1;
int flag = 0, pos = 0;
for(auto i:e[now]) {
if(i == fa) continue;
int nowcnt = cnt[color];
dfs(i, now);
if(cnt[color] > nowcnt) {
flag++;
pos = i;
}
ans[color] += 1ll * sz[now] * sz[i];
sz[now] += sz[i];
}
if(k || cnt[color] != tot[color] - 1) {
flag++;
}
cnt[color]++;
ans[color] += 1ll * sz[now] * (n - sz[now]);
if(flag == 1) {
cntend[color]++;
if(ans2[color] == 0) ans2[color] = 1;
int p = pos ? n - sz[pos] : sz[now];
ans2[color] *= 1ll * p;
}
}
int main(){
cin >> n;
for(int i = 1; i <= n; i++) scanf("%d", &a[i]), tot[a[i]]++;
for(int i = 1, u, v; i < n; i++) {
scanf("%d%d", &u, &v);
e[u].pb(v);
e[v].pb(u);
}
dfs(1, 0);
for(int i = 1; i <= n; i++){
if(tot[i] == 0) printf("%lld\n", 1ll * n * (n - 1) / 2);
else if(tot[i] == 1) printf("%lld\n",ans[i]);
else{
if(cntend[i] == 2) printf("%lld\n",ans2[i]);
else puts("0");
}
}puts("");
system("pause");
return 0;
}