LC 0250 [M] Count Univalue Subtrees - ALawliet/algorithms GitHub Wiki

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    """LeetCode 250 (Count Univalue Subtrees)."""

    def countUnivalSubtrees(self, root: "TreeNode") -> int:
        """Return the number of subtrees in which every node has the same value.

        A subtree is unival when both of its child subtrees (if present) are
        unival and each present child's ``val`` equals the node's ``val``.
        Every leaf is therefore unival on its own.

        The traversal is an iterative post-order walk. Very deep (e.g.
        chain-shaped) trees therefore cannot hit Python's recursion limit.

        Args:
            root: Root of the binary tree, or ``None``/falsy for an empty tree.
                Nodes are expected to expose ``val``, ``left`` and ``right``.

        Returns:
            The count of unival subtrees; ``0`` for an empty tree.
        """
        # The annotation is a string because TreeNode is supplied by the judge
        # (see the commented definition) and may not exist at class-definition time.
        if not root:
            return 0

        count = 0
        # Unival flags of finished subtrees whose parent has not been finished yet.
        results = []
        # (node, expanded): expanded=False means children not yet scheduled.
        stack = [(root, False)]
        while stack:
            node, expanded = stack.pop()
            if not expanded:
                # Revisit this node after its children. Push right before left
                # so the left subtree is finished first.
                stack.append((node, True))
                if node.right:
                    stack.append((node.right, False))
                if node.left:
                    stack.append((node.left, False))
                continue

            # Children finished left-then-right, so the right child's flag is on top.
            right_ok = results.pop() if node.right else True
            left_ok = results.pop() if node.left else True
            is_unival = (
                left_ok
                and right_ok
                and (not node.left or node.left.val == node.val)
                and (not node.right or node.right.val == node.val)
            )
            if is_unival:
                count += 1
            results.append(is_unival)

        return count