310. Minimum Height Trees (Medium) - TengnanYao/daily_leetcode GitHub Wiki
class Solution:
    def findMinHeightTrees(self, n: int, edges: List[List[int]]) -> List[int]:
        """Return the roots of all minimum-height trees of an undirected tree.

        Repeatedly strips the current layer of leaves (degree-1 nodes). The
        one or two nodes left when at most two remain are the roots that
        minimize the tree's height.

        Args:
            n: Number of nodes, labelled ``0 .. n-1``.
            edges: Undirected edges ``[a, b]``; expected to form a tree on
                ``n`` nodes.

        Returns:
            The labels of the remaining (one or two) root nodes, in ascending
            order, as a list to match the annotated return type.

        Raises:
            ValueError: If trimming runs out of leaves while more than two
                nodes remain. This can only happen when ``edges`` do not form
                a tree (e.g. they contain a cycle); previously the loop would
                spin forever in that case.
        """
        neighbors = defaultdict(set)
        for a, b in edges:
            neighbors[a].add(b)
            neighbors[b].add(a)

        remaining = set(range(n))
        leaves = {node for node, adj in neighbors.items() if len(adj) == 1}
        while len(remaining) > 2:
            if not leaves:
                # A tree with >= 2 nodes always has >= 2 leaves, so an empty
                # leaf layer means the input is not a tree; fail loudly
                # instead of looping forever.
                raise ValueError("edges do not form a tree on n nodes")
            remaining -= leaves
            next_leaves = set()
            for leaf in leaves:
                for nb in neighbors[leaf]:
                    # Detach the removed leaf; a neighbor left with exactly
                    # one edge becomes part of the next leaf layer.
                    neighbors[nb].remove(leaf)
                    if len(neighbors[nb]) == 1:
                        next_leaves.add(nb)
            leaves = next_leaves
        return sorted(remaining)