跳转至

3879. Maximum Distinct Path Sum in a Binary Tree 🔒

题目描述

You are given the root of a binary tree, where each node contains an integer value.

A valid path in the tree is a sequence of connected nodes such that:

  • The path can start and end at any node in the tree.
  • The path does not need to pass through the root.
  • All node values along the path are distinct.

Return an integer denoting the maximum possible sum of node values among all valid paths.

 

Example 1:

Input: root = [2,2,1]

Output: 3

Explanation:

  • The path 2 → 2 is invalid because the value 2 is not distinct.
  • The maximum-sum valid path is 2 → 1, with a sum = 2 + 1 = 3.

Example 2:

Input: root = [1,-2,5,null,null,3,5]

Output: 9

Explanation:

  • The path 3 → 5 → 5 is invalid due to duplicate value 5.
  • The maximum-sum valid path is 1 → 5 → 3, with a sum = 1 + 5 + 3 = 9.

Example 3:


Input: root = [4,6,6,null,null,null,9]

Output: 19

Explanation:

  • The path 6 → 4 → 6 → 9 is invalid because the value 6 appears more than once.
  • The maximum-sum valid path is 4 → 6 → 9, with a sum = 4 + 6 + 9 = 19.

 

Constraints:

  • The number of nodes in the tree is in the range [1, 1000].
  • -1000 <= Node.val <= 1000

解法

方法一:DFS + 哈希表

我们可以将树看成一个无向图,使用一个哈希表 \(g\) 来存储每个节点的相邻节点,其中 \(g[node]\) 包含节点 \(node\) 的父节点、左子节点和右子节点。

我们使用深度优先搜索来遍历树,并构建哈希表 \(g\)。对于每个节点,我们将其父节点、左子节点和右子节点添加到 \(g[node]\) 中。

接下来,我们使用另一个深度优先搜索来计算以每个节点为起点的最大路径和。在这个过程中,我们使用一个哈希集合 \(vis\) 来记录当前路径上已经出现过的节点值,以确保路径上的节点值互不相同。对于每个节点,如果它为空,或者它的值已经在 \(vis\) 中,则返回 0。否则,我们将节点值加入 \(vis\),并将最佳延伸和 \(best\) 初始化为 0(表示路径可以在当前节点结束)。接着遍历 \(g[node]\) 中的相邻节点,递归地计算以相邻节点为起点的路径和,并用它更新 \(best\)。由于来时经过的节点的值已经在 \(vis\) 中,搜索不会折返,因此无需额外记录访问过的节点。最后,我们将当前节点值从 \(vis\) 中移除(回溯),并返回当前节点值加上 \(best\)。

我们对树中的每个节点执行上述计算,并记录最大路径和。由于任意一条路径都有端点,以每个节点为起点进行枚举即可覆盖所有路径。最终返回最大路径和作为答案。

时间复杂度 \(O(n^2)\),空间复杂度 \(O(n)\)。其中 \(n\) 是树中的节点数。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def maxSum(self, root: Optional[TreeNode]) -> int:
        """Return the maximum sum of a tree path whose node values are all distinct.

        The tree is treated as an undirected graph (each node is linked to its
        parent and children). Every path has an endpoint, so trying every node
        as the start of a one-directional walk covers all paths.

        Both traversals use explicit stacks instead of recursion: a chain-shaped
        tree with 1000 nodes would otherwise recurse about 1000 frames deep and
        exceed CPython's default recursion limit.

        Args:
            root: Root of the tree; may be None.

        Returns:
            The maximum path sum, or ``-inf`` when ``root`` is None.
        """
        # Adjacency lists: parent, left child, right child (absent ones dropped).
        g = {}
        stack = [(root, None)] if root is not None else []
        while stack:
            node, parent = stack.pop()
            g[node] = [nb for nb in (parent, node.left, node.right) if nb is not None]
            for child in (node.left, node.right):
                if child is not None:
                    stack.append((child, node))

        def best_from(start):
            # Best sum of a distinct-valued walk beginning at `start`. The walk
            # may stop at any node, so every running total is a candidate.
            best = -inf
            on_path = set()  # values of the nodes on the current start->node walk
            # Entries: (node, running total including node, is_exit_marker).
            todo = [(start, start.val, False)]
            while todo:
                node, total, leaving = todo.pop()
                if leaving:
                    # Backtrack: the walk no longer contains this node.
                    on_path.remove(node.val)
                    continue
                best = max(best, total)
                on_path.add(node.val)
                # The exit marker is popped only after every walk through
                # `node` has been explored, mirroring a recursive post-order.
                todo.append((node, total, True))
                for nxt in g[node]:
                    # The node we came from is skipped automatically: its value
                    # is already in on_path.
                    if nxt.val not in on_path:
                        todo.append((nxt, total + nxt.val, False))
            return best

        ans = -inf
        for node in g:
            ans = max(ans, best_from(node))
        return ans
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
class Solution {
    // Undirected adjacency: g.get(node) = [parent, left, right]; absent neighbours are null.
    Map<TreeNode, List<TreeNode>> g = new HashMap<>();
    // Values of the nodes on the walk currently being explored by dfs2.
    Set<Integer> vis = new HashSet<>();

    /**
     * Returns the maximum sum of a tree path whose node values are pairwise distinct.
     *
     * <p>Every path has an endpoint, so trying each node as the start of a
     * one-directional walk covers all paths.
     *
     * @param root root of the tree
     * @return the best path sum, or {@code Integer.MIN_VALUE} if {@code root} is null
     */
    public int maxSum(TreeNode root) {
        // The fields outlive a single call; without this reset, calling maxSum
        // again on the same instance would also iterate (and score) nodes left
        // over from the previous tree.
        g.clear();
        vis.clear();
        dfs(root, null);

        int ans = Integer.MIN_VALUE;
        for (TreeNode node : g.keySet()) {
            ans = Math.max(ans, dfs2(node));
            vis.clear();
        }
        return ans;
    }

    /** Records the parent, left child and right child of every node in the subtree of {@code node}. */
    private void dfs(TreeNode node, TreeNode p) {
        if (node == null) {
            return;
        }
        List<TreeNode> adj = g.computeIfAbsent(node, k -> new ArrayList<>());
        adj.add(p);
        adj.add(node.left);
        adj.add(node.right);

        dfs(node.left, node);
        dfs(node.right, node);
    }

    /**
     * Returns the largest sum of a walk starting at {@code node} that never repeats a value
     * already in {@code vis}, or 0 if {@code node} is null or its value is already used.
     */
    private int dfs2(TreeNode node) {
        if (node == null || vis.contains(node.val)) {
            return 0;
        }
        vis.add(node.val);
        // Starting at 0 lets the walk stop here when no extension is profitable.
        // The node we came from is never re-entered because its value is in vis.
        int best = 0;
        for (TreeNode nxt : g.getOrDefault(node, Collections.emptyList())) {
            best = Math.max(best, dfs2(nxt));
        }
        vis.remove(node.val);
        return node.val + best;
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
class Solution {
public:
    int maxSum(TreeNode* root) {
        unordered_map<TreeNode*, vector<TreeNode*>> g;
        unordered_set<int> vis;

        auto dfs = [&](this auto&& dfs, TreeNode* node, TreeNode* p) -> void {
            if (!node) return;
            g[node].push_back(p);
            g[node].push_back(node->left);
            g[node].push_back(node->right);
            dfs(node->left, node);
            dfs(node->right, node);
        };

        auto dfs2 = [&](this auto&& dfs2, TreeNode* node) -> int {
            if (!node || vis.count(node->val)) return 0;
            vis.insert(node->val);
            int res = node->val;
            int best = 0;
            for (auto nxt : g[node]) {
                best = max(best, dfs2(nxt));
            }
            vis.erase(node->val);
            return res + best;
        };

        dfs(root, nullptr);

        int ans = INT_MIN;
        for (auto& [node, _] : g) {
            ans = max(ans, dfs2(node));
            vis.clear();
        }
        return ans;
    }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
/**
 * Definition for a binary tree node.
 * type TreeNode struct {
 *     Val int
 *     Left *TreeNode
 *     Right *TreeNode
 * }
 */
// maxSum returns the maximum sum of a tree path whose node values are pairwise
// distinct, trying every node as the start of a one-directional walk.
func maxSum(root *TreeNode) int {
    // adj[n] holds n's parent, left child and right child (nil when absent).
    adj := make(map[*TreeNode][]*TreeNode)

    var build func(cur, parent *TreeNode)
    build = func(cur, parent *TreeNode) {
        if cur == nil {
            return
        }
        adj[cur] = append(adj[cur], parent, cur.Left, cur.Right)
        build(cur.Left, cur)
        build(cur.Right, cur)
    }
    build(root, nil)

    // onPath holds the values of the nodes on the walk being extended.
    onPath := make(map[int]struct{})

    // extend returns the best sum of a walk from cur that repeats no value in
    // onPath; the zero floor on tail lets the walk stop at cur.
    var extend func(cur *TreeNode) int
    extend = func(cur *TreeNode) int {
        if cur == nil {
            return 0
        }
        if _, dup := onPath[cur.Val]; dup {
            return 0
        }
        onPath[cur.Val] = struct{}{}
        tail := 0
        for _, nxt := range adj[cur] {
            tail = max(tail, extend(nxt))
        }
        delete(onPath, cur.Val)
        return cur.Val + tail
    }

    best := math.MinInt
    for start := range adj {
        best = max(best, extend(start))
        clear(onPath)
    }
    return best
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
/**
 * Definition for a binary tree node.
 * class TreeNode {
 *     val: number
 *     left: TreeNode | null
 *     right: TreeNode | null
 *     constructor(val?: number, left?: TreeNode | null, right?: TreeNode | null) {
 *         this.val = (val===undefined ? 0 : val)
 *         this.left = (left===undefined ? null : left)
 *         this.right = (right===undefined ? null : right)
 *     }
 * }
 */
/**
 * Returns the maximum sum of a tree path whose node values are pairwise distinct.
 * Every node is tried as the start of a one-directional walk over the tree,
 * viewed as an undirected graph.
 */
function maxSum(root: TreeNode | null): number {
    // Undirected adjacency: parent, left child, right child (null when absent).
    const adjacency = new Map<TreeNode, (TreeNode | null)[]>();

    const link = (cur: TreeNode | null, parent: TreeNode | null): void => {
        if (cur == null) return;
        const list = adjacency.get(cur) ?? [];
        list.push(parent, cur.left, cur.right);
        adjacency.set(cur, list);
        link(cur.left, cur);
        link(cur.right, cur);
    };

    // Values of the nodes on the walk currently being extended.
    const seen = new Set<number>();

    // Best sum of a walk from `cur` that repeats no value in `seen`; the
    // reduction starts at 0 so the walk may stop at `cur`.
    const extend = (cur: TreeNode | null): number => {
        if (cur == null || seen.has(cur.val)) return 0;
        seen.add(cur.val);
        const tail = (adjacency.get(cur) ?? []).reduce(
            (acc: number, nxt: TreeNode | null) => Math.max(acc, extend(nxt)),
            0,
        );
        seen.delete(cur.val);
        return cur.val + tail;
    };

    link(root, null);

    let answer = -Infinity;
    for (const start of adjacency.keys()) {
        answer = Math.max(answer, extend(start));
        seen.clear();
    }
    return answer;
}

评论