124. Binary Tree Maximum Path Sum - cocoder39/coco39_LC GitHub Wiki

124. Binary Tree Maximum Path Sum

notes 2024

class Solution:
    def maxPathSum(self, root: Optional[TreeNode]) -> int:
        """Return the largest sum over all non-empty paths in the tree.

        A path may start and end at any node. For every node we need the best
        "down" gain of each child (best sum of a path that starts at the child
        and only descends, clamped at 0 so a negative branch is skipped):

        * the best path whose topmost node is this node is
          ``left_gain + right_gain + node.val``;
        * the gain handed to the parent is ``max(left_gain, right_gain) + node.val``
          because a path continuing upward can use only one branch.

        The traversal is an explicit-stack post-order walk rather than a
        recursive DFS, so a degenerate (linked-list shaped) tree deeper than
        Python's recursion limit does not raise ``RecursionError``.

        Args:
            root: Root node exposing ``val``, ``left`` and ``right``; may be None.

        Returns:
            The maximum path sum, or ``-inf`` when ``root`` is None (no path
            exists). The result is also stored on ``self.max_path_sum``.
        """
        best = -float('inf')
        # id(node) -> best down-path sum starting at that (already finished) node.
        # Keyed by id() so nodes do not have to be hashable.
        down_gain = {}
        stack = [(root, False)] if root is not None else []

        while stack:
            node, children_done = stack.pop()
            if not children_done:
                # Revisit this node only after both subtrees are finished.
                stack.append((node, True))
                for child in (node.left, node.right):
                    if child is not None:
                        stack.append((child, False))
                continue

            # A missing child has no entry and contributes 0; entries are
            # popped because each one is consumed exactly once by its parent.
            left = max(0, down_gain.pop(id(node.left), 0))
            right = max(0, down_gain.pop(id(node.right), 0))

            best = max(best, left + right + node.val)
            down_gain[id(node)] = max(left, right) + node.val

        # Kept so code that reads the attribute after the call still works.
        self.max_path_sum = best
        return best

======================

Notes 2022


class PathSum:
    """Bundle of path sums describing one subtree.

    Attributes are stored exactly as passed in:
        downPathSum: sum of a path that starts at the subtree root and only
            descends.
        crossPathSum: sum of a path whose topmost node is the subtree root.
        maxCrossPathSum: best cross-path sum seen anywhere in the subtree;
            defaults to None when not supplied.
    """

    def __init__(self, downPathSum: int, crossPathSum: int, maxCrossPathSum=None):
        (
            self.downPathSum,
            self.crossPathSum,
            self.maxCrossPathSum,
        ) = (downPathSum, crossPathSum, maxCrossPathSum)
        

class Solution:
    def maxPathSum(self, root: Optional[TreeNode]) -> int:
        """Return the best cross-path sum recorded for the whole tree."""
        return self.helper(root).maxCrossPathSum

    def helper(self, root: TreeNode) -> 'PathSum':
        """Summarise the subtree rooted at ``root`` as a ``PathSum``.

        An empty subtree yields ``PathSum(0, 0)``, whose ``maxCrossPathSum``
        is left as None to signal "no path here".
        """
        if not root:
            return PathSum(0, 0)

        # Left subtree is always summarised before the right one.
        summaries = [self.helper(child) for child in (root.left, root.right)]

        # A child's down path is only worth attaching when it is positive.
        left_gain, right_gain = [max(0, s.downPathSum) for s in summaries]

        through_root = root.val + left_gain + right_gain
        down_from_root = root.val + max(left_gain, right_gain)

        # An empty child (maxCrossPathSum is None) must never win the max.
        best_below = [
            s.maxCrossPathSum if s.maxCrossPathSum is not None else -float('inf')
            for s in summaries
        ]
        best_overall = max(through_root, *best_below)

        return PathSum(down_from_root, through_root, best_overall)

===========================================================

       1
      / \
     2   3

By definition, there are two kinds of paths here. In the first kind, every node on the path has at most one of its children on the path (e.g. 1 -> 3); call this a down path. In the second kind, exactly one node has both its left and right children on the path, and that node is the topmost node of the path (e.g. 2 -> 1 -> 3); call this a cross path.

  1. Without considering the second kind (cross paths), the problem can be solved with:
// Returns the max sum of a down path starting at root (0 for a null root).
    int helper(TreeNode* root) {
        if (!root) return 0;

        // A child whose best down path is negative is skipped (counts as 0).
        const int leftGain = max(helper(root->left), 0);
        const int rightGain = max(helper(root->right), 0);

        return root->val + max(leftGain, rightGain);
    }

2. To handle the second kind, we need to visit each node and treat it as the topmost node that the path crosses. The max sum of a path crossing this node is then

// Returns the max sum of a down path starting at root (0 for a null root),
// and stores in `res` the sum of the best path that crosses root.
// NOTE(review): `res` is not declared in this snippet; presumably it lives in
// an enclosing scope. Confirm before reusing this code as-is.
int helper(TreeNode* root) {
    if (!root) return 0;

    // A child whose best down path is negative is skipped (counts as 0).
    const int leftGain = max(helper(root->left), 0);
    const int rightGain = max(helper(root->right), 0);

    // Cross path: both branches joined through root.
    res = root->val + leftGain + rightGain;

    // Down path: only one branch can continue upward to the parent.
    return root->val + max(leftGain, rightGain);
}

Hence we get the code below by combining the two snippets above. The DFS takes O(n) time and O(height) call-stack space.

class Solution {
public:
    // Returns the largest path sum in the tree; INT_MIN when root is null.
    int maxPathSum(TreeNode* root) {
        int best = INT_MIN;
        helper(root, best);
        return best;
    }

private:
    // Returns the max sum of a down path starting at root (0 for a null root).
    // While visiting each node, raises `res` to the sum of the best path that
    // crosses that node whenever it beats the current value.
    int helper(TreeNode* root, int& res) {
        if (!root) return 0;

        // A child whose best down path is negative is skipped (counts as 0).
        const int leftGain = max(helper(root->left, res), 0);
        const int rightGain = max(helper(root->right, res), 0);

        // Candidate answer: the path joining both branches through root.
        const int throughRoot = root->val + leftGain + rightGain;
        res = max(res, throughRoot);

        // Only one branch can be extended upward to the parent.
        return root->val + max(leftGain, rightGain);
    }
};