310. Minimum Height Trees (Medium) - TengnanYao/daily_leetcode GitHub Wiki

class Solution:
    def findMinHeightTrees(self, n: int, edges: List[List[int]]) -> List[int]:
        """Return the root labels that yield a minimum-height tree.

        Repeatedly strips the current layer of leaves (degree-1 nodes) until
        at most two nodes remain. Rooting the tree at any surviving node gives
        minimum height.

        Args:
            n: number of nodes, labelled ``0 .. n - 1``.
            edges: undirected edges ``[a, b]``; expected to form a tree
                (connected, ``n - 1`` edges).

        Returns:
            The one or two surviving node labels, in ascending order.

        Raises:
            ValueError: if leaf peeling stalls with more than two nodes left,
                which happens when ``edges`` contain a cycle (not a tree).
        """
        adjacency = defaultdict(set)
        for a, b in edges:
            adjacency[a].add(b)
            adjacency[b].add(a)

        remaining = set(range(n))
        # A node that appears in no edge (only possible in a tree when n == 1)
        # has no adjacency entry, is never a leaf, and therefore survives.
        leaves = {node for node, nbrs in adjacency.items() if len(nbrs) == 1}
        while len(remaining) > 2:
            if not leaves:
                # A tree with more than two nodes always has leaves; without
                # this check a cyclic input would loop forever.
                raise ValueError("edges do not form a tree")
            remaining -= leaves
            next_leaves = set()
            for leaf in leaves:
                for nbr in adjacency[leaf]:
                    adjacency[nbr].remove(leaf)
                    # A neighbour that just dropped to degree 1 belongs to the
                    # next layer of leaves.
                    if len(adjacency[nbr]) == 1:
                        next_leaves.add(nbr)
            leaves = next_leaves
        # Return a list, as annotated, rather than the internal set.
        return sorted(remaining)