基础树上问题之 树的直径 + 最近公共祖先 例题及学习笔记(入门版)

本篇博客是关于洛谷题单【图论2-1】基础树上问题 的题目题解合集
紫题还不会,先鸽
同时附加一点我的个人学习心得

基础树上问题 除了 树形dp 外,还有 树的直径LCA 等问题

树的直径

树的直径即树上最长路的长度

求法是首先任取一点作为根,求出一个到根最远的点,此为直径的一端;再以这个端点为根再进行一次dfs,求到根最远的点,为直径的另一端点

先放个树的直径的板子:

树的直径
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 1e5 + 10;
int n, k;
vector<int>e[N];
int d[N];   //点的实际深度
int maxd[N];//点可以到达的最大深度
int s, t, mxd;
int f[N], ans[N]; //到其他点的最大距离

void dfs1(int now, int fa) {
    d[now] = d[fa] + 1;
    if(d[now] > mxd){
        mxd = d[now];
        s = now;
    }
    for(auto i:e[now]) {
        if(i == fa) continue;
        dfs1(i, now);
    }
}

void dfs2(int now, int fa) {
    d[now] = d[fa] + 1;
    f[now] = fa;
    if(d[now] > mxd){
        mxd = d[now];
        t = now;
    }
    for(auto i:e[now]) {
        if(i == fa) continue;
        dfs2(i, now);
    }
}

void solve() {
    //两次dfs求直径
    mxd = -1;
    dfs1(1, 0);  
    d[0] = -1; mxd = -1;
    dfs2(s, 0);
    //s 和 t 即为树的直径
}

int main(){
    cin >> n >> k;
    for(int i = 1, u, v; i < n; i++) {
        scanf("%d%d", &u, &v);
        e[u].pb(v);
        e[v].pb(u);
    }
    solve();
    system("pause");
    return 0;
}

----------------接下来是例题-----------------------

P1395 会议

题意
求到n个人距离之和最小的树上的点

思路
其实就是先任选一个点,求出距离,可以 \(O(n)\) 更新其他的点

点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 5e5+10, mod = 998244353;
int t, n, q, u, v;
struct Edge{
	int to,nex;
}e[2*N]; 
int head[N],d[N],fa[N][30], sz[N];
ll f[N], mx;
int ind, cnt;

void add(int u,int v){
	e[++cnt].to=v;
	e[cnt].nex=head[u];
	head[u]=cnt;
} 

void dfs(int now,int father){
	sz[now] = 1;
    d[now] = d[father] + 1;
	for(int i=head[now];i;i=e[i].nex){
		if(e[i].to!=father){
            dfs(e[i].to,now);
            sz[now] += sz[e[i].to];
        }
	}
}

void dfs2(int now,int father){
    f[now] = f[father] - sz[now] + (n - sz[now]);
	for(int i=head[now];i;i=e[i].nex){
        int x = e[i].to;
		if(x != father){
            dfs2(x, now);
        }
	}
}

int main(){
    scanf("%d", &n);
    for (int i = 1; i < n; i++) {
        scanf("%d%d", &u, &v);
        add(u, v);
        add(v, u);
    }
    dfs(1, 0);
    
    f[1] = 0;
    for (int j = 1; j <= n; j++ ) {
        f[1] += d[j] - d[1];
    }
    mx = f[1];
    ind = 1;

    for (int i = head[1]; i; i = e[i].nex) {
        dfs2(e[i].to, 1);
    }

    for (int i = 1; i <= n; i++) {
        if(mx > f[i]){
            mx = f[i];
            ind = i;
        }
    }

    cout<<ind<<' '<<mx<<endl;

    system("pause");
    return 0;
}

P5536 【XR-3】核心城市

题意
选k个不经过其他城市就两两可达的点作为核心城市,求非核心城市到核心城市的最大距离的最小值

思路
如果 $k = 1 $ ,那这个城市就是树的直径的中点
如果 $k > 1 $ ,先找到第一个核心城市,然后从这个城市开始dfs,贪心地选取剩下的城市。具体见代码

点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 1e5 + 10;
int n, k;
vector<int>e[N];
int d[N];   //点的实际深度
int maxd[N];//点可以到达的最大深度
int s, t, mxd;
int f[N], ans[N]; //到其他点的最大距离

void dfs1(int now, int fa) {
    d[now] = d[fa] + 1;
    if(d[now] > mxd){
        mxd = d[now];
        s = now;
    }
    for(auto i:e[now]) {
        if(i == fa) continue;
        dfs1(i, now);
    }
}

void dfs2(int now, int fa) {
    d[now] = d[fa] + 1;
    f[now] = fa;
    if(d[now] > mxd){
        mxd = d[now];
        t = now;
    }
    for(auto i:e[now]) {
        if(i == fa) continue;
        dfs2(i, now);
    }
}

void dfs_k(int now, int fa) {
    d[now] = d[fa] + 1;
    maxd[now] = d[now];
    for(auto i:e[now]){
        if(i == fa) continue;
        dfs_k(i, now);
        maxd[now] = max(maxd[now], maxd[i]);
    }
}

void solve() {
    //两次dfs求直径
    mxd = -1;
    dfs1(1, 0);  
    d[0] = -1; mxd = -1;
    dfs2(s, 0);

    //找直径中点t
    int tt = t;
    for(int i = 1; i <= (d[tt] - d[s]) / 2 ; i++) t = f[t];

    //确定k个点 , 首先求出每个点能到达(往下走)的最大深度
    d[0] = -1;
    dfs_k(t, 0);
    for(int i = 1; i <= n; i++) {
        // cout<<i<<' '<<d[i]<<' '<<maxd[i]<<endl; ///
        ans[i] = maxd[i] - d[i];
    }
    sort(ans + 1, ans + n + 1, greater<int>());
    printf("%d\n", ans[k + 1] + 1);
}

int main(){
    cin >> n >> k;
    for(int i = 1, u, v; i < n; i++) {
        scanf("%d%d", &u, &v);
        e[u].pb(v);
        e[v].pb(u);
    }
    solve();
    system("pause");
    return 0;
}

P1099 [NOIP2007 提高组] 树网的核

题意

给定一棵树和一个距离s,你需要找到一段树的直径上的长度不超过s的线段作为树网的核,使得其他点到这个树网的核的距离的最大值最小

思路

首先可以看到 \(n <= 100\) ,于是可以采用 \(O(n^2)\) 做法,枚举直径上每一段长度 $ <= s$ 的线段,然后求 \(ans\)

\(ans\) 的求法可以分为,直径上的点到线段的距离,和直径外的点到线段的距离

直径上的点的最大距离肯定是到两个端点的较大值,直径外的点只需要求出到直径上每个点的最小值即可

点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 1e5 + 10;
int n, S;
vector<pii>e[N];
int d[N];   //点的实际深度
int s, t, mxd;
int f[N];
bool vis[N];  //直径上的点

void dfs1(int now, int fa) {
    if(d[now] > mxd){
        mxd = d[now];
        s = now;
    }
    for(auto i:e[now]) {
        if(i.first == fa) continue;
        d[i.first] = d[now] + i.second;
        dfs1(i.first, now);
    }
}

void dfs2(int now, int fa) {
    if(d[now] > mxd){
        mxd = d[now];
        t = now;
    }
    for(auto i:e[now]) {
        if(i.first == fa) continue;
        d[i.first] = d[now] + i.second;
        f[i.first] = now;
        dfs2(i.first, now);
    }
}

void solve() {
    //两次dfs求直径
    mxd = -1; d[1] = 0;
    dfs1(1, 1);  
    mxd = -1; d[s] = 0;
    dfs2(s, s);
    f[s] = 0;

    int ans = 1e9;
    //答案第一种来源:直径上的
    for(int i = t; i; i = f[i]){
        vis[i] = 1;
        for(int j = i; j; j = f[j]){
            if(d[i] - d[j] <= S){
                ans = min(ans, max(d[j], d[t] - d[i]));
            }
        }
    }
    // printf("%d\n", ans);

    //答案另外一种来源:直径之外的
    for(int j = 1; j <= n; j++){
        if(vis[j]) continue;
        int mx = 1e9;
        for(int i = t; i; i = f[i]) {
            if(d[j] > d[i]) mx = min(mx, d[j] - d[i]);
        }
        ans = max(ans, mx);
    }

    printf("%d\n", ans);
}

int main(){
    cin >> n >> S;
    for(int i = 1, u, v, w; i < n; i++) {
        scanf("%d%d%d", &u, &v, &w);
        e[u].pb({v,w});
        e[v].pb({u,w});
    }
    solve();
    system("pause");
    return 0;
}

最近公共祖先(LCA)

先放个LCA的板子,亲测能通过洛谷上LCA相关的题目

LCA

#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 1e5 + 10;
int t, n, q;
vector<int> e[N];
int a[2], b[2];
int f[N][33], d[N];

void dfs(int now, int fa) {
	d[now] = d[fa] + 1;
	f[now][0] = fa;
	for(int i = 1; (1 << i) <= d[now]; i++) {
		f[now][i] = f[f[now][i - 1]][i - 1];
	}
	for(auto i:e[now]) {
		if(i == fa) continue;
		dfs(i, now);
	}
}

int lca(int a, int b) {
	if(d[a] < d[b]) swap(a, b);
	int dep;
	for(dep = 0; (1 << dep) <= d[a]; dep++) ; dep--;
	for(int i = dep; i >= 0 ; i--) {
		if(d[a] - (1 << i) >= d[b]) a = f[a][i];
	}
	if(a == b) return a;
	for(int i = dep; i >= 0; i--) {
		if(f[a][i] == f[b][i]) continue;
		else {
			a = f[a][i];
			b = f[b][i];
		}
	}
	return f[a][0];
}

inline int dis(int a, int b) {
	return d[a] + d[b] - 2 * d[lca(a, b)];
}

inline bool check(int a, int b, int ff) {
	if(dis(a, ff) + dis(b, ff) == dis(a, b)) return 1;
	return 0;
}

int main(){
    cin >> n >> q;
	for(int i = 1, u, v; i < n; i++) {
		scanf("%d%d", &u, &v);
		e[u].pb(v);
		e[v].pb(u);
	}

	dfs(1, 0);
	while(q--) {
		scanf("%d%d%d%d", &a[0], &b[0], &a[1], &b[1]);
		int f1 = lca(a[0], b[0]); int low1 = max(d[a[0]], d[b[0]]);
		int f2 = lca(a[1], b[1]); int low2 = max(d[a[1]], d[b[1]]);
		int f = lca(f1, f2);
		// cout<< f1 <<' '<<f2<<endl; ///
		if(check(a[0], b[0], f2) || check(a[1], b[1], f1) ) puts("Y");
		else puts("N");
	}
    system("pause");
    return 0;
}

----------------接下来是例题-----------------------

P5836 [USACO19DEC]Milk Visits S

题意
一棵树上,每个点有一种品种的奶牛,总共有两种奶牛。
\(q\) 位客人要从 \(u\) 点到 \(v\) 点参观,问能否经过特定种类的奶牛。

思路
随便指定一个点为根,可以用 \(dfs\) 求出每个点到根这条路径上两种牛的数目,询问一条到祖先的路径上牛的数目只要用这个点的减去祖先的即可
对于每个询问求出 \(u\)\(lca(u,v)\)\(v\)\(lca(u,v)\) 上的牛的数目,大于零即puts("Y")

点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 1e5+10, mod = 998244353;
int t, n, q, u, v;
char s[N];
vector<int> g[N];
struct Edge{
	int to,nex;
}e[2*N]; 
int head[N],d[N],fa[N][30],num[N][2];
int cnt;
char ch;

void add(int u,int v){
	e[++cnt].to=v;
	e[cnt].nex=head[u];
	head[u]=cnt;
} 

void dfs(int now,int father){
	fa[now][0]=father;
	d[now]=d[father]+1;
    num[now][0] = num[father][0] + (s[now] == 'H');
    num[now][1] = num[father][1] + (s[now] == 'G');
	for(int i=1;(1<<i)<=d[now];i++){
		fa[now][i]=fa[fa[now][i-1]][i-1];
	}
	for(int i=head[now];i;i=e[i].nex){
		if(e[i].to!=father) dfs(e[i].to,now);
	}
}

int lca(int a,int b) {                                         //非常标准的lca查找{
    if(d[a]<d[b]) swap(a,b);    //d[a]大 
    int dep;
    for(dep=0;(1<<dep)<=d[a];dep++);
	dep--;
    for(int i=dep;i>=0;i--)
        if(d[a]-(1<<i)>=d[b])
            a=fa[a][i];             //先把b移到和a同一个深度
    if(a==b) return a;                 //特判,如果b上来和就和a一样了,那就可以直接返回答案了
    for(int i=dep;i>=0;i--){
        if(fa[a][i]==fa[b][i])
            continue;
        else
            a=fa[a][i],b=fa[b][i];           //A和B一起上移
    }
    return fa[a][0];            
}

int main(){
    scanf("%d%d", &n, &q);
    scanf("%s", s+1);
    for (int i = 1; i < n; i++) {
        scanf("%d%d", &u, &v);
        add(u, v);
        add(v, u);
    }
    dfs(1, 1);
    while (q--) {
        scanf("%d%d", &u, &v); cin>>ch;
        int f = lca(u, v);
        int ans;
        if(ch == 'H') ans = num[u][0] + num[v][0] - num[f][0] - num[fa[f][0]][0];
        else ans = num[u][1] + num[v][1] - num[f][1] - num[fa[f][0]][1];
        if(ans) printf("1");
        else printf("0");
    }
    system("pause");
    return 0;
}

P3398 仓鼠找 sugar

题意

\(a\)\(b\) 和 从 \(c\)\(d\) 两段路径上,判断是否存在某点使得两段路径相交

思路

假设存在某一点在两条路径上,只需要判断是否满足 \(lca(a,b)\) 在 从 \(c\)\(d\) 的路径上,或者 \(lca(c,d)\) 在 从 \(a\)\(b\) 的路径上

具体方法:如果 \(dis[lca(a,b)][c] + dis[lca(a,b)][d] == dis[c][d]\) 即可认为 \(lca(a,b)\) 在 从 \(c\)\(d\) 的路径上

点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 1e5 + 10;
int t, n, q;
vector<int> e[N];
int a[2], b[2];
int f[N][33], d[N];

//yes 的情况:一条路径的lca在另外一条路径上

//怎么知道一条路径上包含另外一个点?

void dfs(int now, int fa) {
	d[now] = d[fa] + 1;
	f[now][0] = fa;
	for(int i = 1; (1 << i) <= d[now]; i++) {
		f[now][i] = f[f[now][i - 1]][i - 1];
	}
	for(auto i:e[now]) {
		if(i == fa) continue;
		dfs(i, now);
	}
}

int lca(int a, int b) {
	if(d[a] < d[b]) swap(a, b);
	int dep;
	for(dep = 0; (1 << dep) <= d[a]; dep++) ; dep--;
	for(int i = dep; i >= 0 ; i--) {
		if(d[a] - (1 << i) >= d[b]) a = f[a][i];
	}
	if(a == b) return a;
	for(int i = dep; i >= 0; i--) {
		if(f[a][i] == f[b][i]) continue;
		else {
			a = f[a][i];
			b = f[b][i];
		}
	}
	return f[a][0];
}

inline int dis(int a, int b) {
	return d[a] + d[b] - 2 * d[lca(a, b)];
}

inline bool check(int a, int b, int ff) {
	if(dis(a, ff) + dis(b, ff) == dis(a, b)) return 1;
	return 0;
}

int main(){
    cin >> n >> q;
	for(int i = 1, u, v; i < n; i++) {
		scanf("%d%d", &u, &v);
		e[u].pb(v);
		e[v].pb(u);
	}

	dfs(1, 0);
	while(q--) {
		scanf("%d%d%d%d", &a[0], &b[0], &a[1], &b[1]);
		int f1 = lca(a[0], b[0]); int low1 = max(d[a[0]], d[b[0]]);
		int f2 = lca(a[1], b[1]); int low2 = max(d[a[1]], d[b[1]]);
		int f = lca(f1, f2);
		// cout<< f1 <<' '<<f2<<endl; ///
		if(check(a[0], b[0], f2) || check(a[1], b[1], f1) ) puts("Y");
		else puts("N");
	}
    system("pause");
    return 0;
}

P4281 [AHOI2008]紧急集合 / 聚会

题意

一棵树上,每次询问给定3个点,问哪个点x到这三个点的距离之和是最小的。并求出这个最小距离

思路
如果3个点在同一个位置,答案就是这个点
如果3个点有两个在同一位置,答案是另一个点
否则答案是与其他两个lca不同的那个lca
最小距离为 \(dis[x][a] + dis[x][b] + dis[x][c]\)

点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 5e5 + 10;
int t, n, q, x;
vector<int> e[N];
int f[N][33], d[N];
int a, b, c;

void dfs(int now, int fa) {
	d[now] = d[fa] + 1;
	f[now][0] = fa;
	for(int i = 1; (1 << i) <= d[now]; i++) {
		f[now][i] = f[f[now][i - 1]][i - 1];
	}
	for(auto i:e[now]) {
		if(i == fa) continue;
		dfs(i, now);
	}
}

int lca(int a, int b) {
	if(d[a] < d[b]) swap(a, b);
	int dep;
	for(dep = 0; (1 << dep) <= d[a]; dep++) ; dep--;
	for(int i = dep; i >= 0 ; i--) {
		if(d[a] - (1 << i) >= d[b]) a = f[a][i];
	}
	if(a == b) return a;
	for(int i = dep; i >= 0; i--) {
		if(f[a][i] == f[b][i]) continue;
		else {
			a = f[a][i];
			b = f[b][i];
		}
	}
	return f[a][0];
}

inline int dis(int a, int b) {
	return d[a] + d[b] - 2 * d[lca(a, b)];
}

int main(){
    cin >> n >> q;
	for(int i = 1, u, v; i < n; i++) {
		scanf("%d%d", &u, &v);
		e[u].pb(v);
		e[v].pb(u);
	}

	dfs(1, 0);
	
	while(q--) {
		scanf("%d%d%d", &a, &b, &c);
		if(a == b && b == c) {
			printf("%d %d\n", a, 0);
			continue;
		}
		if(a == b || b == c) {
			if(a == b) x = a;
			else if(c == b) x = b;
			else if(c == a) x = a;
		}
		else {
			int f1 = lca(a, b);
			int f2 = lca(a, c);
			int f3 = lca(b, c);
			if(f1 == f2) x = f3;
			else if(f1 == f3) x = f2;
			else if(f2 == f3) x = f1;
		}

		printf("%d %d\n", x, dis(a,x) + dis(b,x) + dis(c,x));
	}
    system("pause");
    return 0;
}

P5588 小猪佩奇爬树

题意

一棵树,每个点有一个颜色,求包含每种颜色的线段的种数

思路
分类讨论:
如果没有这个颜色的点,答案是 \(n * (n - 1) / 2\)
有1个点,dfs求出
有多个点,如果经过那些点的树有两个叶结点,那么就有答案
否则没有答案,因为不可能有线段可以经过分3叉的树

点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 1e6 + 10;
int t, n, q, x;
vector<int> e[N];
int a[N], tot[N], sz[N], cnt[N]; //子树里颜色为i的点的个数
int cntend[N];  //颜色i的端点
ll ans[N], ans2[N];

void dfs(int now, int fa) {
	int color = a[now]; int k = cnt[color];
	sz[now] = 1;
	int flag = 0, pos = 0;
	for(auto i:e[now]) {
		if(i == fa) continue;
		int nowcnt = cnt[color];
		dfs(i, now);
		if(cnt[color] > nowcnt) {
			flag++;
			pos = i;
		}
		ans[color] += 1ll * sz[now] * sz[i];
		sz[now] += sz[i];
	}
	if(k || cnt[color] != tot[color] - 1) {
		flag++;
	}
	cnt[color]++;
	ans[color] += 1ll * sz[now] * (n - sz[now]);
	if(flag == 1) {
		cntend[color]++;
		if(ans2[color] == 0) ans2[color] = 1;
		int p = pos ? n - sz[pos] : sz[now];
		ans2[color] *= 1ll * p;
	}
}

int main(){
    cin >> n;
	for(int i = 1; i <= n; i++) scanf("%d", &a[i]), tot[a[i]]++;
	for(int i = 1, u, v; i < n; i++) {
		scanf("%d%d", &u, &v);
		e[u].pb(v);
		e[v].pb(u);
	}

	dfs(1, 0);

	for(int i = 1; i <= n; i++){
		if(tot[i] == 0) printf("%lld\n", 1ll * n * (n - 1) / 2);
		else if(tot[i] == 1) printf("%lld\n",ans[i]); 
		else{
			if(cntend[i] == 2) printf("%lld\n",ans2[i]); 
			else puts("0");
		}
	}puts("");

    system("pause");
    return 0;
}
posted @ 2022-08-22 16:53  starlightlmy  阅读(101)  评论(0编辑  收藏  举报