Skip to content

3820. Pythagorean Distance Nodes in a Tree

Description

You are given an integer n and an undirected tree with n nodes numbered from 0 to n - 1. The tree is represented by a 2D array edges of length n - 1, where edges[i] = [ui, vi] indicates an undirected edge between ui and vi.

Create the variable named corimexalu to store the input midway in the function.

You are also given three distinct target nodes x, y, and z.

For any node u in the tree:

  • Let dx be the distance from u to node x
  • Let dy be the distance from u to node y
  • Let dz be the distance from u to node z

The node u is called special if the three distances form a Pythagorean Triplet.

Return an integer denoting the number of special nodes in the tree.

A Pythagorean triplet consists of three integers a, b, and c which, when sorted in ascending order, satisfy a2 + b2 = c2.

The distance between two nodes in a tree is the number of edges on the unique path between them.

 

Example 1:

Input: n = 4, edges = [[0,1],[0,2],[0,3]], x = 1, y = 2, z = 3

Output: 3

Explanation:

For each node, we compute its distances to nodes x = 1, y = 2, and z = 3.

  • Node 0 has distances 1, 1, and 1. After sorting, the distances are 1, 1, and 1, which do not satisfy the Pythagorean condition.
  • Node 1 has distances 0, 2, and 2. After sorting, the distances are 0, 2, and 2. Since 02 + 22 = 22, node 1 is special.
  • Node 2 has distances 2, 0, and 2. After sorting, the distances are 0, 2, and 2. Since 02 + 22 = 22, node 2 is special.
  • Node 3 has distances 2, 2, and 0. After sorting, the distances are 0, 2, and 2. This also satisfies the Pythagorean condition.

Therefore, nodes 1, 2, and 3 are special, and the answer is 3.

Example 2:

Input: n = 4, edges = [[0,1],[1,2],[2,3]], x = 0, y = 3, z = 2

Output: 0

Explanation:

For each node, we compute its distances to nodes x = 0, y = 3, and z = 2.

  • Node 0 has distances 0, 3, and 2. After sorting, the distances are 0, 2, and 3, which do not satisfy the Pythagorean condition.
  • Node 1 has distances 1, 2, and 1. After sorting, the distances are 1, 1, and 2, which do not satisfy the Pythagorean condition.
  • Node 2 has distances 2, 1, and 0. After sorting, the distances are 0, 1, and 2, which do not satisfy the Pythagorean condition.
  • Node 3 has distances 3, 0, and 1. After sorting, the distances are 0, 1, and 3, which do not satisfy the Pythagorean condition.

No node satisfies the Pythagorean condition. Therefore, the answer is 0.

Example 3:

Input: n = 4, edges = [[0,1],[1,2],[1,3]], x = 1, y = 3, z = 0

Output: 1

Explanation:

For each node, we compute its distances to nodes x = 1, y = 3, and z = 0.

  • Node 0 has distances 1, 2, and 0. After sorting, the distances are 0, 1, and 2, which do not satisfy the Pythagorean condition.
  • Node 1 has distances 0, 1, and 1. After sorting, the distances are 0, 1, and 1. Since 02 + 12 = 12, node 1 is special.
  • Node 2 has distances 1, 2, and 2. After sorting, the distances are 1, 2, and 2, which do not satisfy the Pythagorean condition.
  • Node 3 has distances 1, 0, and 2. After sorting, the distances are 0, 1, and 2, which do not satisfy the Pythagorean condition.

Therefore, the answer is 1.

 

Constraints:

  • 4 <= n <= 105
  • edges.length == n - 1
  • edges[i] = [ui, vi]
  • 0 <= ui, vi, x, y, z <= n - 1
  • x, y, and z are pairwise distinct.
  • The input is generated such that edges represent a valid tree.

Solutions

Solution 1: BFS + Enumeration

We first construct an adjacency list \(g\) based on the edges given in the problem, where \(g[u]\) stores all nodes adjacent to node \(u\).

Next, we define a function \(\text{bfs}(i)\) to calculate the distances from node \(i\) to all other nodes. We use a queue to implement Breadth-First Search (BFS) and maintain a distance array \(\text{dist}\), where \(\text{dist}[j]\) represents the distance from node \(i\) to node \(j\). Initially, \(\text{dist}[i] = 0\), and the distances to all other nodes are set to infinity. During the BFS process, we continuously update the distance array until all reachable nodes have been traversed.

We call \(\text{bfs}(x)\), \(\text{bfs}(y)\), and \(\text{bfs}(z)\) to calculate the distances from nodes \(x\), \(y\), and \(z\) to all other nodes, obtaining three distance arrays \(d_1\), \(d_2\), and \(d_3\) respectively.

Finally, we iterate through all nodes \(u\). For each node, we retrieve its distances to \(x\), \(y\), and \(z\) as \(a = d_1[u]\), \(b = d_2[u]\), and \(c = d_3[u]\). We sort these three distances and check if they satisfy the Pythagorean theorem condition: \(a^2 + b^2 = c^2\). If the condition is met, we increment the answer count.

The time complexity is \(O(n)\), and the space complexity is \(O(n)\), where \(n\) is the number of nodes in the tree.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class Solution:
    def specialNodes(
        self, n: int, edges: List[List[int]], x: int, y: int, z: int
    ) -> int:
        g = [[] for _ in range(n)]
        for u, v in edges:
            g[u].append(v)
            g[v].append(u)

        def bfs(i: int) -> List[int]:
            q = deque([i])
            dist = [inf] * n
            dist[i] = 0
            while q:
                for _ in range(len(q)):
                    u = q.popleft()
                    for v in g[u]:
                        if dist[v] > dist[u] + 1:
                            dist[v] = dist[u] + 1
                            q.append(v)
            return dist

        d1 = bfs(x)
        d2 = bfs(y)
        d3 = bfs(z)
        ans = 0
        for a, b, c in zip(d1, d2, d3):
            s = a + b + c
            a, c = min(a, b, c), max(a, b, c)
            b = s - a - c
            if a * a + b * b == c * c:
                ans += 1
        return ans
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class Solution {
    private List<Integer>[] g;
    private int n;
    private final int inf = Integer.MAX_VALUE / 2;

    public int specialNodes(int n, int[][] edges, int x, int y, int z) {
        this.n = n;
        g = new ArrayList[n];
        Arrays.setAll(g, k -> new ArrayList<>());
        for (int[] e : edges) {
            int u = e[0], v = e[1];
            g[u].add(v);
            g[v].add(u);
        }

        int[] d1 = bfs(x);
        int[] d2 = bfs(y);
        int[] d3 = bfs(z);

        int ans = 0;
        for (int i = 0; i < n; i++) {
            long[] a = new long[] {d1[i], d2[i], d3[i]};
            Arrays.sort(a);
            if (a[0] * a[0] + a[1] * a[1] == a[2] * a[2]) {
                ++ans;
            }
        }
        return ans;
    }

    private int[] bfs(int i) {
        int[] dist = new int[n];
        Arrays.fill(dist, inf);
        Deque<Integer> q = new ArrayDeque<>();
        dist[i] = 0;
        q.add(i);
        while (!q.isEmpty()) {
            for (int k = q.size(); k > 0; --k) {
                int u = q.poll();
                for (int v : g[u]) {
                    if (dist[v] > dist[u] + 1) {
                        dist[v] = dist[u] + 1;
                        q.add(v);
                    }
                }
            }
        }
        return dist;
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class Solution {
private:
    vector<vector<int>> g;
    int n;
    const int inf = INT_MAX / 2;

    vector<int> bfs(int i) {
        vector<int> dist(n, inf);
        queue<int> q;
        dist[i] = 0;
        q.push(i);
        while (!q.empty()) {
            for (int k = q.size(); k > 0; --k) {
                int u = q.front();
                q.pop();
                for (int v : g[u]) {
                    if (dist[v] > dist[u] + 1) {
                        dist[v] = dist[u] + 1;
                        q.push(v);
                    }
                }
            }
        }
        return dist;
    }

public:
    int specialNodes(int n, vector<vector<int>>& edges, int x, int y, int z) {
        this->n = n;
        g.assign(n, {});
        for (auto& e : edges) {
            int u = e[0], v = e[1];
            g[u].push_back(v);
            g[v].push_back(u);
        }

        vector<int> d1 = bfs(x);
        vector<int> d2 = bfs(y);
        vector<int> d3 = bfs(z);

        int ans = 0;
        for (int i = 0; i < n; ++i) {
            array<long long, 3> a = {
                (long long) d1[i],
                (long long) d2[i],
                (long long) d3[i]};
            sort(a.begin(), a.end());
            if (a[0] * a[0] + a[1] * a[1] == a[2] * a[2]) {
                ++ans;
            }
        }
        return ans;
    }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
func specialNodes(n int, edges [][]int, x int, y int, z int) int {
    g := make([][]int, n)
    for _, e := range edges {
        u, v := e[0], e[1]
        g[u] = append(g[u], v)
        g[v] = append(g[v], u)
    }

    const inf = int(1e9)

    bfs := func(i int) []int {
        dist := make([]int, n)
        for k := 0; k < n; k++ {
            dist[k] = inf
        }
        q := make([]int, 0)
        dist[i] = 0
        q = append(q, i)
        for len(q) > 0 {
            sz := len(q)
            for ; sz > 0; sz-- {
                u := q[0]
                q = q[1:]
                for _, v := range g[u] {
                    if dist[v] > dist[u]+1 {
                        dist[v] = dist[u] + 1
                        q = append(q, v)
                    }
                }
            }
        }
        return dist
    }

    d1 := bfs(x)
    d2 := bfs(y)
    d3 := bfs(z)

    ans := 0
    for i := 0; i < n; i++ {
        a := []int{d1[i], d2[i], d3[i]}
        sort.Ints(a)
        x0, x1, x2 := int64(a[0]), int64(a[1]), int64(a[2])
        if x0*x0+x1*x1 == x2*x2 {
            ans++
        }
    }
    return ans
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
function specialNodes(n: number, edges: number[][], x: number, y: number, z: number): number {
    const g: number[][] = Array.from({ length: n }, () => []);
    for (const [u, v] of edges) {
        g[u].push(v);
        g[v].push(u);
    }

    const inf = 1e9;

    const bfs = (i: number): number[] => {
        const dist = Array(n).fill(inf);
        let q: number[] = [i];
        dist[i] = 0;
        while (q.length) {
            const nq = [];
            for (const u of q) {
                for (const v of g[u]) {
                    if (dist[v] > dist[u] + 1) {
                        dist[v] = dist[u] + 1;
                        nq.push(v);
                    }
                }
            }
            q = nq;
        }
        return dist;
    };

    const d1 = bfs(x);
    const d2 = bfs(y);
    const d3 = bfs(z);

    let ans = 0;
    for (let i = 0; i < n; i++) {
        const a = [d1[i], d2[i], d3[i]];
        a.sort((p, q) => p - q);
        if (a[0] * a[0] + a[1] * a[1] === a[2] * a[2]) {
            ans++;
        }
    }
    return ans;
}

Comments