Skip to content

3879. Maximum Distinct Path Sum in a Binary Tree 🔒

Description

You are given the root of a binary tree, where each node contains an integer value.

A valid path in the tree is a sequence of connected nodes such that:

  • The path can start and end at any node in the tree.
  • The path does not need to pass through the root.
  • All node values along the path are distinct.

Return an integer denoting the maximum possible sum of node values among all valid paths.

Β 

Example 1:

Input: root = [2,2,1]

Output: 3

Explanation:

  • The path 2 → 2 is invalid because the value 2 is not distinct.
  • The maximum-sum valid path is 2 → 1, with a sum = 2 + 1 = 3.

Example 2:

Input: root = [1,-2,5,null,null,3,5]

Output: 9

Explanation:

  • The path 3 → 5 → 5 is invalid due to duplicate value 5.
  • The maximum-sum valid path is 1 → 5 → 3, with a sum = 1 + 5 + 3 = 9.

Example 3:

​​​​​​​

Input: root = [4,6,6,null,null,null,9]

Output: 19

Explanation:

  • The path 6 → 4 → 6 → 9 is invalid because the value 6 appears more than once.
  • The maximum-sum valid path is 4 → 6 → 9, with a sum = 4 + 6 + 9 = 19.

Β 

Constraints:

  • The number of nodes in the tree is in the range [1, 1000].
  • -1000 <= Node.val <= 1000

Solutions

Solution 1: DFS + Hash Table

We can treat the tree as an undirected graph, using a hash table \(g\) to store the adjacent nodes of each node, where \(g[node]\) contains the parent node, left child node, and right child node of \(node\).

We use depth-first search to traverse the tree and build the hash table \(g\). For each node, we add its parent node, left child node, and right child node to \(g[node]\).

Next, we use another depth-first search to compute the maximum path sum starting from each node. During this process, we use a hash set \(vis\) to record the node values already used on the current path, ensuring all node values along the path are distinct. For each node, we first check whether its value is already in \(vis\) (or the node is null); if so, we return \(0\). Otherwise, we add the node's value to \(vis\) and compute the best path sum that starts at this node. We traverse the adjacent nodes in \(g[node]\), recursively compute the path sum starting from each adjacent node, and keep the largest one, using \(0\) as the baseline so that the path may also end at the current node. Finally, we remove the current node's value from \(vis\) (backtracking) and return the current node's value plus that best extension.

We perform the above computation for every node in the tree and track the maximum path sum. The final answer is the maximum path sum.

The time complexity is \(O(n^2)\), and the space complexity is \(O(n)\), where \(n\) is the number of nodes in the tree.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def maxSum(self, root: Optional[TreeNode]) -> int:
        """Return the largest sum over all paths whose node values are distinct.

        The tree is first turned into an undirected adjacency map (parent,
        left child, right child for every node).  Then, from every node, a
        backtracking walk extends the path through neighbours whose value has
        not been used yet and keeps the best total.
        """
        adj = defaultdict(list)  # node -> [parent, left, right]; entries may be None
        seen = set()  # values used on the path currently being explored

        def link(cur, parent):
            # Record the three potential neighbours of `cur`, then descend.
            if cur is not None:
                adj[cur].extend((parent, cur.left, cur.right))
                link(cur.left, cur)
                link(cur.right, cur)

        def walk(cur):
            # Best sum of a valid path that starts at `cur`; 0 if `cur` is
            # missing or its value is already on the current path.
            if cur is None or cur.val in seen:
                return 0
            seen.add(cur.val)
            # 0 as a baseline lets the path stop at `cur`.
            tail = max(0, *(walk(nb) for nb in adj[cur]))
            seen.discard(cur.val)  # backtrack
            return cur.val + tail

        link(root, None)
        best = -inf
        for start in adj:
            best = max(best, walk(start))
            seen.clear()
        return best
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
class Solution {
    // Undirected adjacency view of the tree: node -> [parent, left, right].
    // Entries may be null when the corresponding neighbour does not exist.
    Map<TreeNode, List<TreeNode>> g = new HashMap<>();
    // Node values used on the path currently being explored by dfs2.
    Set<Integer> vis = new HashSet<>();

    /**
     * Returns the maximum sum over all paths in the tree whose node values
     * are pairwise distinct.
     *
     * @param root root of the binary tree
     * @return the best path sum; {@code Integer.MIN_VALUE} if the tree is empty
     */
    public int maxSum(TreeNode root) {
        // The graph and the visited set are instance fields, so they must be
        // reset on every call; otherwise nodes from a previously processed
        // tree would still be iterated below and could leak into the answer.
        g.clear();
        vis.clear();

        dfs(root, null);

        int ans = Integer.MIN_VALUE;
        for (TreeNode node : g.keySet()) {
            // dfs2 removes every value it adds, so vis is empty again
            // once each top-level call returns.
            ans = Math.max(ans, dfs2(node));
        }
        return ans;
    }

    /**
     * Records the parent and both children of {@code node} as its neighbours
     * and recurses into the children.
     *
     * @param node current node, may be null
     * @param p    parent of {@code node}, null for the root
     */
    private void dfs(TreeNode node, TreeNode p) {
        if (node == null) {
            return;
        }
        List<TreeNode> adj = g.computeIfAbsent(node, k -> new ArrayList<>());
        adj.add(p);
        adj.add(node.left);
        adj.add(node.right);

        dfs(node.left, node);
        dfs(node.right, node);
    }

    /**
     * Computes the best sum of a valid path that starts at {@code node} and
     * only continues through neighbours whose value is not already in use.
     *
     * @param node starting node, may be null
     * @return 0 if {@code node} is null or its value is already on the
     *         current path; otherwise the node's value plus the best
     *         non-negative extension
     */
    private int dfs2(TreeNode node) {
        if (node == null || vis.contains(node.val)) {
            return 0;
        }
        vis.add(node.val);
        // Starting from 0 lets the path end at this node.
        int best = 0;
        for (TreeNode nxt : g.getOrDefault(node, Collections.emptyList())) {
            best = Math.max(best, dfs2(nxt));
        }
        vis.remove(node.val); // backtrack
        return node.val + best;
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
class Solution {
public:
    int maxSum(TreeNode* root) {
        unordered_map<TreeNode*, vector<TreeNode*>> g;
        unordered_set<int> vis;

        auto dfs = [&](this auto&& dfs, TreeNode* node, TreeNode* p) -> void {
            if (!node) return;
            g[node].push_back(p);
            g[node].push_back(node->left);
            g[node].push_back(node->right);
            dfs(node->left, node);
            dfs(node->right, node);
        };

        auto dfs2 = [&](this auto&& dfs2, TreeNode* node) -> int {
            if (!node || vis.count(node->val)) return 0;
            vis.insert(node->val);
            int res = node->val;
            int best = 0;
            for (auto nxt : g[node]) {
                best = max(best, dfs2(nxt));
            }
            vis.erase(node->val);
            return res + best;
        };

        dfs(root, nullptr);

        int ans = INT_MIN;
        for (auto& [node, _] : g) {
            ans = max(ans, dfs2(node));
            vis.clear();
        }
        return ans;
    }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
/**
 * Definition for a binary tree node.
 * type TreeNode struct {
 *     Val int
 *     Left *TreeNode
 *     Right *TreeNode
 * }
 */
// maxSum returns the maximum sum over all paths in the tree whose node
// values are pairwise distinct. For a nil root it returns math.MinInt.
func maxSum(root *TreeNode) int {
    // adj is an undirected view of the tree: node -> [parent, left, right].
    // Entries may be nil when the neighbour does not exist.
    adj := make(map[*TreeNode][]*TreeNode)

    var link func(cur, parent *TreeNode)
    link = func(cur, parent *TreeNode) {
        if cur != nil {
            adj[cur] = append(adj[cur], parent, cur.Left, cur.Right)
            link(cur.Left, cur)
            link(cur.Right, cur)
        }
    }

    // seen holds the values used on the path currently being explored.
    seen := make(map[int]bool)

    // walk returns the best sum of a valid path starting at cur, or 0 when
    // cur is nil or its value is already on the current path.
    var walk func(cur *TreeNode) int
    walk = func(cur *TreeNode) int {
        if cur == nil || seen[cur.Val] {
            return 0
        }
        seen[cur.Val] = true
        tail := 0 // 0 lets the path stop at cur
        for _, nb := range adj[cur] {
            tail = max(tail, walk(nb))
        }
        delete(seen, cur.Val) // backtrack
        return cur.Val + tail
    }

    link(root, nil)

    best := math.MinInt
    for start := range adj {
        best = max(best, walk(start))
        clear(seen)
    }
    return best
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
/**
 * Definition for a binary tree node.
 * class TreeNode {
 *     val: number
 *     left: TreeNode | null
 *     right: TreeNode | null
 *     constructor(val?: number, left?: TreeNode | null, right?: TreeNode | null) {
 *         this.val = (val===undefined ? 0 : val)
 *         this.left = (left===undefined ? null : left)
 *         this.right = (right===undefined ? null : right)
 *     }
 * }
 */
/**
 * Returns the maximum sum over all paths in the tree whose node values are
 * pairwise distinct. Yields `-Infinity` when the tree is empty.
 *
 * The tree is first converted into an undirected adjacency map, then a
 * backtracking walk is started from every node.
 */
function maxSum(root: TreeNode | null): number {
    // node -> [parent, left, right]; entries may be null.
    const adjacency = new Map<TreeNode, (TreeNode | null)[]>();
    // Values used on the path currently being explored.
    const seen = new Set<number>();

    const link = (cur: TreeNode | null, parent: TreeNode | null): void => {
        if (!cur) return;
        const nbrs = adjacency.get(cur) ?? [];
        nbrs.push(parent, cur.left, cur.right);
        adjacency.set(cur, nbrs);
        link(cur.left, cur);
        link(cur.right, cur);
    };

    // Best sum of a valid path starting at `cur`; 0 when `cur` is missing or
    // its value is already on the current path.
    const walk = (cur: TreeNode | null): number => {
        if (!cur || seen.has(cur.val)) return 0;
        seen.add(cur.val);
        // Folding from 0 lets the path stop at `cur`.
        const tail = (adjacency.get(cur) ?? []).reduce(
            (acc: number, nb: TreeNode | null) => Math.max(acc, walk(nb)),
            0,
        );
        seen.delete(cur.val); // backtrack
        return cur.val + tail;
    };

    link(root, null);

    let best = -Infinity;
    for (const [start] of adjacency) {
        best = Math.max(best, walk(start));
        seen.clear();
    }
    return best;
}

Comments