Minimum Height Trees

Difficulty: Medium Source: NeetCode

Problem

A tree is an undirected graph in which any two vertices are connected by exactly one path. In other words, any connected graph without simple cycles is a tree.

Given a tree of n nodes labeled from 0 to n - 1, and an array of n - 1 edges where edges[i] = [ai, bi] indicates that there is an undirected edge between the two nodes ai and bi in the tree, you can choose any node of the tree as the root. When you select a node x as the root, the resulting tree has height h. Among all possible rooted trees, those with minimum height are called minimum height trees (MHTs).

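The height of a rooted tree is the number of edges on the longest downward path between the root and a leaf. Return a list of all MHTs’ root labels, in any order.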

Example 1:
Input: n = 4, edges = [[1,0],[1,2],[1,3]]
Output: [1]

Example 2:
Input: n = 6, edges = [[3,0],[3,1],[3,2],[3,4],[5,4]]
Output: [3,4]
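To make the height definition concrete, here is a small sketch that computes the height for every possible root of Example 2 with a recursive DFS. The helper name all_root_heights is hypothetical, and the approach is quadratic and recursion-bound, so treat it as an illustration for small inputs only.

```python
def all_root_heights(n: int, edges: list[list[int]]) -> list[int]:
    """Height (edges on the longest root-to-leaf path) for every choice of root."""
    graph = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)

    def height(node: int, parent: int) -> int:
        # A node with no children other than its parent is a leaf: height 0.
        return max((1 + height(child, node) for child in graph[node] if child != parent), default=0)

    return [height(root, -1) for root in range(n)]


heights = all_root_heights(6, [[3, 0], [3, 1], [3, 2], [3, 4], [5, 4]])
print(heights)                                                  # [3, 3, 3, 2, 2, 3]
print([v for v, h in enumerate(heights) if h == min(heights)])  # [3, 4]
```

Roots 3 and 4 both give height 2, which is why Example 2 has two answers.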

Constraints:

  • 1 <= n <= 2 * 10^4
  • edges.length == n - 1
  • 0 <= ai, bi < n
  • ai != bi
  • All the pairs (ai, bi) are distinct
  • The given input is guaranteed to be a tree
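The constraints promise a valid tree. If the edges come from somewhere less trustworthy, a union-find pass is one cheap way to check the definition above: exactly n - 1 edges and no cycle together imply the graph is connected, hence a tree. This is a sketch; is_tree is a hypothetical helper, not part of the problem.

```python
def is_tree(n: int, edges: list[list[int]]) -> bool:
    """Check that `edges` forms a tree on nodes 0..n-1 (hypothetical helper)."""
    if len(edges) != n - 1:
        return False
    parent = list(range(n))

    def find(x: int) -> int:
        # Path halving keeps the union-find forest shallow.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    for a, b in edges:
        ra, rb = find(a), find(b)
        if ra == rb:  # a and b are already connected, so this edge closes a cycle
            return False
        parent[ra] = rb
    # n - 1 edges and no cycle means the graph is connected.
    return True


print(is_tree(4, [[1, 0], [1, 2], [1, 3]]))  # True
print(is_tree(4, [[0, 1], [1, 2], [2, 0]]))  # False
```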

Prerequisites

Before attempting this problem, you should be comfortable with:

  • Tree properties: the center(s) of a tree (the middle node or nodes of a longest path) minimize the height
  • Topological peeling (leaf removal) — iteratively removing leaf nodes finds the center
  • BFS — used to process nodes level by level during peeling
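The first prerequisite can be checked directly with a different technique from the one this page builds up to: find a longest path (the diameter) with two BFS passes, then take its middle node or nodes. This is a sketch of that double-BFS method, not the page's solution; farthest and tree_centers are hypothetical names.

```python
from collections import deque


def farthest(graph: list[list[int]], start: int) -> tuple[int, list[int]]:
    """BFS from start; return the farthest node and the BFS parent array."""
    parent = [-1] * len(graph)
    seen = [False] * len(graph)
    seen[start] = True
    queue = deque([start])
    last = start
    while queue:
        last = queue.popleft()  # BFS pops nodes in nondecreasing distance order
        for nxt in graph[last]:
            if not seen[nxt]:
                seen[nxt] = True
                parent[nxt] = last
                queue.append(nxt)
    return last, parent


def tree_centers(n: int, edges: list[list[int]]) -> list[int]:
    graph = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)
    u, _ = farthest(graph, 0)       # one end of a longest path
    v, parent = farthest(graph, u)  # the other end
    path = [v]
    while path[-1] != u:
        path.append(parent[path[-1]])
    mid = (len(path) - 1) // 2
    # Odd number of nodes on the path: one middle node; even: two.
    return sorted(path[mid:mid + 1] if len(path) % 2 else path[mid:mid + 2])


print(tree_centers(6, [[3, 0], [3, 1], [3, 2], [3, 4], [5, 4]]))  # [3, 4]
```

It relies on the standard fact that, in a tree, a BFS from any node ends at an endpoint of some longest path.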

1. Brute Force (BFS from every node)

Intuition

For each node, run BFS to compute the height of the tree rooted at that node. Track the minimum height and collect all nodes that achieve it. This is correct but takes O(N²) time: fine for small trees, but at n = 2 * 10^4 that is on the order of 4 * 10^8 steps, too slow for the constraints.

Algorithm

  1. For each node from 0 to n - 1, run BFS to compute the height of the tree rooted at that node.
  2. Collect all nodes that achieve the minimum height.
  3. Return them.

Solution

from collections import deque

def findMinHeightTrees(n: int, edges: list[list[int]]) -> list[int]:
    if n == 1:
        return [0]

    graph = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)

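    def bfs_height(root):
        # Level-order BFS: each pass of the outer loop consumes one full level.
        visited = {root}
        queue = deque([root])
        height = 0
        while queue:
            for _ in range(len(queue)):
                node = queue.popleft()
                for neighbor in graph[node]:
                    if neighbor not in visited:
                        visited.add(neighbor)
                        queue.append(neighbor)
            height += 1
        # `height` counts levels; the tree height counts edges, which is one fewer.
        return height - 1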

    min_h = float('inf')
    result = []
    for node in range(n):
        h = bfs_height(node)
        if h < min_h:
            min_h = h
            result = [node]
        elif h == min_h:
            result.append(node)
    return result


print(findMinHeightTrees(4, [[1,0],[1,2],[1,3]]))        # [1]
print(findMinHeightTrees(6, [[3,0],[3,1],[3,2],[3,4],[5,4]]))  # [3, 4]
print(findMinHeightTrees(1, []))                           # [0]

Complexity

  • Time: O(N²): N BFS runs, each O(N) because a tree has N - 1 edges
  • Space: O(N)

2. Topological Peeling (Leaf Trimming)

Intuition

The key insight: the roots of the minimum height trees are exactly the center node(s) of the tree, the middle node or nodes of a longest path, and there are at most 2 of them. We find them by repeatedly “peeling” off the leaf nodes (degree 1), one layer at a time from the outside inward. This is Kahn’s topological sort adapted to an undirected tree, with degree 1 playing the role of in-degree 0. We stop when 1 or 2 nodes remain; those are the answer.

Think of it like finding the middle of a path by trimming one node from each end per step: both ends shrink toward the middle at the same rate, and whatever is left at the end (one node or two) is the middle.
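To see the layers, here is a sketch that records which leaves get removed in each round, with the surviving center(s) as the last entry. The helper name peel_layers is hypothetical; it seeds the queue with degree <= 1 so that the single-node tree is handled without a special case.

```python
def peel_layers(n: int, edges: list[list[int]]) -> list[list[int]]:
    """Return the leaves removed in each round, followed by the surviving center(s)."""
    graph = [set() for _ in range(n)]
    for a, b in edges:
        graph[a].add(b)
        graph[b].add(a)
    # Degree <= 1 also catches the lone node when n == 1.
    leaves = [v for v in range(n) if len(graph[v]) <= 1]
    remaining = n
    layers = []
    while remaining > 2:
        layers.append(leaves)
        remaining -= len(leaves)
        next_leaves = []
        for leaf in leaves:
            neighbor = graph[leaf].pop()  # a leaf has exactly one neighbor left
            graph[neighbor].discard(leaf)
            if len(graph[neighbor]) == 1:
                next_leaves.append(neighbor)
        leaves = next_leaves
    layers.append(leaves)  # the last entry holds the center(s)
    return layers


print(peel_layers(6, [[3, 0], [3, 1], [3, 2], [3, 4], [5, 4]]))  # [[0, 1, 2, 5], [3, 4]]
print(peel_layers(5, [[0, 1], [1, 2], [2, 3], [3, 4]]))          # [[0, 4], [1, 3], [2]]
```

On the path 0-1-2-3-4 the trace is exactly the “trim one node from each end” picture.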

Algorithm

  1. Build the adjacency list and compute the degree of each node.
  2. Initialize a queue with all leaf nodes (degree == 1).
  3. Set remaining = n. While remaining > 2:
    • Subtract the number of current leaves from remaining, then remove each of those leaves.
    • For each removed leaf’s neighbor, decrement its degree. If the degree drops to 1, the neighbor is a leaf of the next round, so enqueue it.
  4. Return the nodes left in the queue; they are exactly the 1 or 2 remaining nodes.

Solution

from collections import deque

def findMinHeightTrees(n: int, edges: list[list[int]]) -> list[int]:
    if n == 1:
        return [0]
    if n == 2:
        return [0, 1]

    graph = [set() for _ in range(n)]
    for a, b in edges:
        graph[a].add(b)
        graph[b].add(a)

    # Start with all leaf nodes
    leaves = deque(node for node in range(n) if len(graph[node]) == 1)
    remaining = n

    while remaining > 2:
        remaining -= len(leaves)
        new_leaves = deque()
        while leaves:
            leaf = leaves.popleft()
            # There's exactly one neighbor for a leaf
            neighbor = next(iter(graph[leaf]))
            graph[neighbor].discard(leaf)
            if len(graph[neighbor]) == 1:
                new_leaves.append(neighbor)
        leaves = new_leaves

    return list(leaves)


print(findMinHeightTrees(4, [[1,0],[1,2],[1,3]]))             # [1]
print(findMinHeightTrees(6, [[3,0],[3,1],[3,2],[3,4],[5,4]])) # [3, 4]
print(findMinHeightTrees(1, []))                               # [0]
print(findMinHeightTrees(2, [[0,1]]))                          # [0, 1]
print(findMinHeightTrees(7, [[0,1],[1,2],[1,3],[2,4],[3,5],[4,6]]))  # [1, 2]

Complexity

  • Time: O(N) — each node and edge processed at most once
  • Space: O(N)
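Because the brute force is easy to trust, a randomized cross-check against it is a cheap way to gain confidence in the peeling solution. The sketch below restates both approaches compactly (the brute force stores BFS distances instead of counting levels) and compares them on random trees; brute_force, peeling, random_tree, and cross_check are hypothetical names.

```python
import random
from collections import deque


def brute_force(n: int, edges: list[list[int]]) -> list[int]:
    """O(N^2) reference: BFS from every node, keep the roots of minimum height."""
    graph = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)

    def height(root: int) -> int:
        dist = [-1] * n
        dist[root] = 0
        queue = deque([root])
        while queue:
            node = queue.popleft()
            for nxt in graph[node]:
                if dist[nxt] < 0:
                    dist[nxt] = dist[node] + 1
                    queue.append(nxt)
        return max(dist)  # farthest distance from root = height in edges

    heights = [height(v) for v in range(n)]
    best = min(heights)
    return [v for v in range(n) if heights[v] == best]


def peeling(n: int, edges: list[list[int]]) -> list[int]:
    """O(N) leaf peeling, same loop as the solution above."""
    if n == 1:
        return [0]
    graph = [set() for _ in range(n)]
    for a, b in edges:
        graph[a].add(b)
        graph[b].add(a)
    leaves = [v for v in range(n) if len(graph[v]) == 1]
    remaining = n
    while remaining > 2:
        remaining -= len(leaves)
        new_leaves = []
        for leaf in leaves:
            neighbor = next(iter(graph[leaf]))
            graph[neighbor].discard(leaf)
            if len(graph[neighbor]) == 1:
                new_leaves.append(neighbor)
        leaves = new_leaves
    return sorted(leaves)  # sorted so results compare equal to brute_force


def random_tree(n: int, rng: random.Random) -> list[list[int]]:
    # Attaching each node v >= 1 to a random earlier node always yields a tree.
    return [[v, rng.randrange(v)] for v in range(1, n)]


def cross_check(trials: int = 300, max_n: int = 30, seed: int = 0) -> bool:
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randint(1, max_n)
        edges = random_tree(n, rng)
        expected, got = brute_force(n, edges), peeling(n, edges)
        assert got == expected, (n, edges, expected, got)
    return True


print(cross_check())  # True
```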

Common Pitfalls

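Edge case n = 1. A single node with no edges has height 0 and is trivially the answer. Handle it before building the adjacency list: with no edges, no node has degree 1, so the leaf queue starts empty, the loop never runs (remaining is 1), and the function would return [] instead of [0].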
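A quick way to see the failure is to run the same peeling loop with the guard deliberately removed. The function name peel_without_guard is hypothetical; it exists only to demonstrate the bug.

```python
def peel_without_guard(n: int, edges: list[list[int]]) -> list[int]:
    # Same peeling loop as the solution, but deliberately WITHOUT the n == 1 check.
    graph = [set() for _ in range(n)]
    for a, b in edges:
        graph[a].add(b)
        graph[b].add(a)
    leaves = [v for v in range(n) if len(graph[v]) == 1]
    remaining = n
    while remaining > 2:
        remaining -= len(leaves)
        new_leaves = []
        for leaf in leaves:
            neighbor = next(iter(graph[leaf]))
            graph[neighbor].discard(leaf)
            if len(graph[neighbor]) == 1:
                new_leaves.append(neighbor)
        leaves = new_leaves
    return leaves


print(peel_without_guard(1, []))  # [] instead of the correct [0]
```

An alternative to the early return is to seed the queue with len(graph[node]) <= 1: in a valid tree the only node of degree 0 is the lone node when n = 1.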

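Stopping condition: remaining > 2. A tree has at most 2 center nodes: a longest path (the “diameter”) has one middle node when its edge count is even and two when it is odd. Stop as soon as 1 or 2 nodes remain rather than when the queue is empty; peeling one more round would remove the centers themselves (in the code above, next(iter(graph[leaf])) would raise StopIteration on a center whose neighbors are already gone). The loop also relies on the input being a valid tree: on a disconnected input containing a cycle, the leaf queue can run empty while remaining > 2, and the loop would never terminate.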
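If inputs might violate the constraints, one option is to fail loudly when the leaf queue runs empty too early. This is a sketch with a hypothetical name; it only catches that empty-queue case and is not a general input validator (other malformed inputs can still fail differently).

```python
def find_min_height_trees_checked(n: int, edges: list[list[int]]) -> list[int]:
    """Leaf peeling that raises instead of looping forever when no leaves remain."""
    if n == 1:
        return [0]
    graph = [set() for _ in range(n)]
    for a, b in edges:
        graph[a].add(b)
        graph[b].add(a)
    leaves = [v for v in range(n) if len(graph[v]) == 1]
    remaining = n
    while remaining > 2:
        if not leaves:
            # A valid tree always has leaves while more than 2 nodes remain.
            raise ValueError("input is not a tree: no leaves left to peel")
        remaining -= len(leaves)
        new_leaves = []
        for leaf in leaves:
            neighbor = next(iter(graph[leaf]))
            graph[neighbor].discard(leaf)
            if len(graph[neighbor]) == 1:
                new_leaves.append(neighbor)
        leaves = new_leaves
    return leaves


print(find_min_height_trees_checked(6, [[3, 0], [3, 1], [3, 2], [3, 4], [5, 4]]))  # [3, 4]
try:
    # Triangle 0-1-2 plus isolated node 3: n - 1 edges, but not a tree.
    find_min_height_trees_checked(4, [[0, 1], [1, 2], [2, 0]])
except ValueError as err:
    print(err)  # input is not a tree: no leaves left to peel
```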

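Using a set for adjacency. With set adjacency, graph[neighbor].discard(leaf) takes O(1) time on average. Removing from a list (list.remove) is O(degree), which on a star graph adds up to O(N²) because every leaf must be removed from the center’s list. Keeping list adjacency plus an explicit degree array avoids removals altogether.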
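The Algorithm steps mention a degree array, while the solution above tracks degree implicitly as len(graph[node]). The sketch below follows the steps literally with list adjacency and a degree array, so nothing is ever removed: a peeled leaf’s neighbors are only read, and a neighbor that was already peeled has degree at most 1, so decrementing it can only give 0 and never re-enqueues it. The function name is hypothetical.

```python
from collections import deque


def find_min_height_trees_degrees(n: int, edges: list[list[int]]) -> list[int]:
    """Leaf peeling with list adjacency and an explicit degree array (no removals)."""
    if n == 1:
        return [0]
    graph = [[] for _ in range(n)]
    degree = [0] * n
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)
        degree[a] += 1
        degree[b] += 1

    leaves = deque(v for v in range(n) if degree[v] == 1)
    remaining = n
    while remaining > 2:
        remaining -= len(leaves)
        for _ in range(len(leaves)):  # process exactly this round's leaves
            leaf = leaves.popleft()
            for neighbor in graph[leaf]:
                degree[neighbor] -= 1
                if degree[neighbor] == 1:  # neighbor just became a leaf
                    leaves.append(neighbor)
    return sorted(leaves)


print(find_min_height_trees_degrees(7, [[0, 1], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6]]))  # [1, 2]
```

Each leaf’s adjacency list is scanned once when it is peeled, so the total work stays O(N).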