LC 0310 [M] Minimum Height Trees - ALawliet/algorithms GitHub Wiki

https://www.youtube.com/watch?v=OsvbLAaRmu8&ab_channel=HappyCoding

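Approach: leaf-removal topological sort. Repeatedly strip every current leaf (degree-1 node) from the tree. The one or two nodes left at the end are the roots of the minimum height trees.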

class Solution:
    """LeetCode 310: Minimum Height Trees, solved by peeling leaves layer by layer."""

    def findMinHeightTrees(self, n: int, edges: List[List[int]]) -> List[int]:
        """Return every root that yields a tree of minimum height.

        Repeatedly removes all current leaves (nodes of degree 1), akin to
        Kahn's topological sort applied to an undirected tree. When at most
        two nodes remain, those nodes are the answer.

        Args:
            n: number of nodes, labelled 0..n-1.
            edges: undirected edges [u, v]; assumed to form a tree on n nodes
                (NOTE: non-tree input, e.g. a cycle with no leaves, is not
                handled).

        Returns:
            A list of one or two node labels (order unspecified).
        """
        if n == 1:
            return [0]

        # Sets make detaching one neighbor O(1); with lists, a high-degree
        # node (e.g. the center of a star) costs O(degree) per removal.
        graph = defaultdict(set)
        for u, v in edges:
            graph[u].add(v)
            graph[v].add(u)

        # Current frontier of the leaf-peeling "topological sort".
        leaves = [node for node, nbrs in graph.items() if len(nbrs) == 1]

        remaining = n
        # At most two nodes can be in the final result, so stop once <= 2 remain.
        while remaining > 2:
            remaining -= len(leaves)
            next_leaves = []
            for leaf in leaves:
                # A leaf has exactly one neighbor; detach the edge in both directions.
                neighbor = graph[leaf].pop()
                graph[neighbor].remove(leaf)
                # Degrees only decrease, so a node hits degree 1 at most once:
                # a list cannot collect duplicates here.
                if len(graph[neighbor]) == 1:
                    next_leaves.append(neighbor)
            leaves = next_leaves

        # Always a list, matching the declared return type (previously a set
        # was returned whenever the loop ran).
        return leaves