LC 0333 [M] Largest BST Subtree - ALawliet/algorithms GitHub Wiki

https://www.youtube.com/watch?v=t3kHAoRT5iQ O(n)

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def largestBSTSubtree(self, root: TreeNode) -> int:
        """Return the node count of the largest subtree of ``root`` that is a BST.

        One post-order pass: each node combines its children's summaries
        (is_bst, min, max, size), so every node is visited exactly once.
        The best size found so far is kept in ``self.res``.
        """
        self.res = 0
        # Summary for an empty child: a valid BST of size 0 whose range
        # (+inf, -inf) lets any parent value pass the range check.
        empty = (True, float('inf'), float('-inf'), 0)
        # Summary for a subtree that is not a BST. Parents only look at the
        # leading flag, so the other fields are placeholders.
        broken = (False, float('inf'), float('-inf'), 0)

        def summarize(node):
            """Return (is_bst, min_val, max_val, size) for the subtree at node."""
            if not node:
                return empty

            l_ok, l_lo, l_hi, l_n = summarize(node.left)
            r_ok, r_lo, r_hi, r_n = summarize(node.right)

            # Guard clause: a broken child, or node.val outside the open
            # interval (largest on the left, smallest on the right).
            if not (l_ok and r_ok and l_hi < node.val < r_lo):
                return broken

            size = l_n + r_n + 1
            self.res = max(self.res, size)
            return True, min(l_lo, node.val), max(r_hi, node.val), size

        summarize(root)
        return self.res

This is definitely not an O(n) solution; please mark the complexity in the title so people aren't confused. For every node, you run the validity check (isBST() below) and count() on it, and each of those can traverse every node in that node's subtree, which is O(n). Doing O(n) work for each of the n nodes gives O(n) * n = O(n^2) in the worst case (e.g. a degenerate, list-like tree).

O(n^2)

class Solution:
    """Brute-force variant: test subtrees top-down for the BST property.

    For each node visited, isBST() and count() may walk that node's whole
    subtree, so the worst case (e.g. a degenerate, list-like tree) is
    O(n^2) time.
    """

    def largestBSTSubtree(self, root: TreeNode) -> int:
        """Return the size of the largest BST subtree of ``root`` (0 if empty)."""
        if not root:
            return 0

        # If this subtree is a BST it is the largest one inside it, because it
        # contains all of its descendants. Stop descending and count it.
        # float('inf') is used because a bare `inf` is not defined in this file.
        if self.isBST(root, float('-inf'), float('inf')):
            return self.count(root)

        return max(self.largestBSTSubtree(root.left),
                   self.largestBSTSubtree(root.right))

    def isBST(self, node, _min, _max):
        """Return True if ``node``'s subtree is a BST whose values all lie
        strictly between ``_min`` and ``_max`` (duplicates are rejected)."""
        if not node:
            return True

        if not _min < node.val < _max:
            return False

        # Left values must stay below node.val; right values must stay above it.
        return (self.isBST(node.left, _min, node.val)
                and self.isBST(node.right, node.val, _max))

    def count(self, node):
        """Return the number of nodes in ``node``'s subtree."""
        if not node:
            return 0

        return 1 + self.count(node.left) + self.count(node.right)

O(n)

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def largestBSTSubtree(self, root: TreeNode) -> int:
        """Return the size of the largest BST subtree of ``root`` in O(n).

        A post-order DFS returns (size, min, max) for every subtree:
          * BST subtree: size is its full node count and (min, max) is its
            value range.
          * Non-BST subtree: size is the largest BST size found inside it, and
            the range is the inverted sentinel (-inf, +inf). That sentinel
            makes every ancestor fail the range check.
        """
        # Bound locally: a bare `inf` is not defined anywhere in this file.
        inf = float('inf')

        def dfs(node):
            if not node:
                # Empty tree: a BST of size 0 whose range (+inf, -inf) never
                # constrains the parent.
                return 0, inf, -inf

            # postorder: summarize both children before the node itself
            l_size, l_min, l_max = dfs(node.left)
            r_size, r_min, r_max = dfs(node.right)

            if l_max < node.val < r_min:
                # Both subtrees are BSTs (a non-BST child reports the
                # (-inf, +inf) sentinel and cannot pass this check) and
                # node.val fits between them, so this whole subtree is a BST.
                return l_size + r_size + 1, min(l_min, node.val), max(r_max, node.val)

            # Not a BST: report the largest BST found in either child and
            # poison the range so no ancestor can be a BST either.
            return max(l_size, r_size), -inf, inf

        return dfs(root)[0]
class Solution(object):
    def largestBSTSubtree(self, root):
        """
        Size of the largest subtree of root that is a valid BST.

        :type root: TreeNode
        :rtype: int
        """
        self.maxsize = 0
        pos_inf, neg_inf = float('inf'), float('-inf')

        def dfs(node):
            # Returns (node count, smallest value, largest value).
            # A count of -1 marks "not a BST"; an empty subtree is
            # (0, +inf, -inf), so it never constrains its parent.
            if not node:
                return 0, pos_inf, neg_inf

            left_n, left_lo, left_hi = dfs(node.left)
            right_n, right_lo, right_hi = dfs(node.right)

            children_ok = left_n >= 0 and right_n >= 0
            if not (children_ok and left_hi < node.val < right_lo):
                return -1, neg_inf, pos_inf

            total = left_n + right_n + 1
            self.maxsize = max(self.maxsize, total)
            # An empty side contributes no bound, so the node's own value
            # is the extreme on that side.
            lowest = left_lo if left_n else node.val
            highest = right_hi if right_n else node.val
            return total, lowest, highest

        dfs(root)
        return self.maxsize
class Solution(object):
    def largestBSTSubtree(self, root):
        """
        :type root: TreeNode
        :rtype: int
        """
        self.rst = 0
        self.dfs(root)
        return self.rst
        
    def dfs(self, root):
        
        if not root:
            return 0, float("inf"), float("-inf")
        
        left = self.dfs(root.left)
        right = self.dfs(root.right)
        
        if left[2] < root.val < right[1]:
            n = left[0] + right[0] + 1
        else:
            n = float("-inf")
            
        self.rst = max(self.rst, n)
        return n, min(root.val, left[1]), max(root.val, right[2])