Count Valid Paths in a Tree

Count Valid Paths in a Tree

There is an undirected tree with n nodes labeled from 1 to n. You are given the integer n and a 2D integer array edges of length n - 1, where edges[i] = [ui, vi] indicates that there is an edge between nodes ui and vi in the tree.

Return the number of valid paths in the tree.

A path (a, b) is valid if there exists exactly one prime number among the node labels in the path from a to b.

Note that:

  • The path (a, b) is a sequence of distinct nodes starting with node a and ending with node b such that every two adjacent nodes in the sequence share an edge in the tree.
  • Path (a, b) and path (b, a) are considered the same and counted only once.

 

Example 1:

Input: n = 5, edges = [[1,2],[1,3],[2,4],[2,5]]
Output: 4
Explanation: The pairs with exactly one prime number on the path between them are: 
- (1, 2) since the path from 1 to 2 contains prime number 2. 
- (1, 3) since the path from 1 to 3 contains prime number 3.
- (1, 4) since the path from 1 to 4 contains prime number 2.
- (2, 4) since the path from 2 to 4 contains prime number 2.
It can be shown that there are only 4 valid paths.

Example 2:

Input: n = 6, edges = [[1,2],[1,3],[2,4],[3,5],[3,6]]
Output: 6
Explanation: The pairs with exactly one prime number on the path between them are: 
- (1, 2) since the path from 1 to 2 contains prime number 2.
- (1, 3) since the path from 1 to 3 contains prime number 3.
- (1, 4) since the path from 1 to 4 contains prime number 2.
- (1, 6) since the path from 1 to 6 contains prime number 3.
- (2, 4) since the path from 2 to 4 contains prime number 2.
- (3, 6) since the path from 3 to 6 contains prime number 3.
It can be shown that there are only 6 valid paths.

 

Constraints:

  • 1 <= n <= 105
  • edges.length == n - 1
  • edges[i].length == 2
  • 1 <= ui, vi <= n
  • The input is generated such that edges represent a valid tree.

 

解题思路

  把简单路径按照经过的最高点进行分类,即根据路径两个端点的最近公共祖先来分类。对于以 $u$ 为根的子树,假设与其相邻的子节点为 $v_i$,由于经过 $u$ 的路径中必须恰好有一个质数节点,因此在子树 $v_i$ 中,含端点 $v_i$ 的路径中最多只能有一个质数节点。

  对此定义 $f(u, 0/1)$ 表示在子树 $u$ 中含端点 $u$ 的路径中恰好有 $0/1$ 个质数节点的数量。当 $u$ 是质数节点时,有 $\displaylines{\begin{cases} f(u,0) = 0 \\ f(u,1) = 1 + \sum{f(v_i,0)} \end{cases}}$。当 $u$ 不是质数节点时,有 $\displaylines{\begin{cases} f(u,0) = 1 + \sum{f(v_i,0)} \\ f(u,1) = \sum{f(v_i,1)} \end{cases}}$。

  再考虑两个端点的最近公共祖先是 $u$ 且经过恰好一个质数节点的简单路径的数量。如果 $u$ 是质数节点,则分成两个部分,一个是端点含 $u$ 的,对应的数量是 $\sum{f(v_i,0)}$。另一个两个端点来自不同的子树 $v_i$,对应的数量是 $\sum{f(v_i,0) \sum\limits_{j>i}{f(v_j,0)}}$。如果 $u$ 不是质数节点,也分成两个部分,一个是端点含 $u$ 的,对应的数量是 $\sum{f(v_i,1)}$。另一个两个端点来自不同的子树 $v_i$,对应的数量是 $\sum{f(v_i,1) \sum\limits_{j \ne i}{f(v_j,0)}}$。

  AC 代码如下,时间复杂度为 $O(n)$:

class Solution {
public:
    long long countPaths(int n, vector<vector<int>>& edges) {
        vector<vector<int>> g(n + 1);
        for (auto &p : edges) {
            g[p[0]].push_back(p[1]);
            g[p[1]].push_back(p[0]);
        }
        vector<int> primes;
        vector<bool> vis(n + 1);
        vis[1] = true;
        for (int i = 2; i <= n; i++) {
            if (!vis[i]) primes.push_back(i);
            for (int j = 0; primes[j] * i <= n; j++) {
                vis[primes[j] * i] = true;
                if (i % primes[j] == 0) break;
            }
        }
        vector<vector<int>> f(n + 1, vector<int>(2));
        long long ans = 0;
        function<void(int, int)> dfs = [&](int u, int p) {
            int s0 = 0, s1 = 0;
            for (auto &v : g[u]) {
                if (v == p) continue;
                dfs(v, u);
                s0 += f[v][0];
                s1 += f[v][1];
            }
            if (!vis[u]) {
                f[u][0] = 0;
                f[u][1] = s0 + 1;
                ans += s0;
                for (auto &v : g[u]) {
                    if (v == p) continue;
                    s0 -= f[v][0];
                    ans += 1ll * f[v][0] * s0;
                }
            }
            else {
                f[u][0] = s0 + 1;
                f[u][1] = s1;
                ans += s1;
                for (auto &v : g[u]) {
                    if (v == p) continue;
                    ans += 1ll * f[v][1] * (s0 - f[v][0]);
                }
            }
        };
        dfs(1, -1);
        return ans;
    }
};
posted @ 2024-03-10 20:43  onlyblues  阅读(6)  评论(0编辑  收藏  举报
Web Analytics