Minimum Height Trees
Difficulty: Medium Source: NeetCode
Problem
A tree is an undirected graph in which any two vertices are connected by exactly one path. In other words, any connected graph without simple cycles is a tree.
Given a tree of n nodes labeled from 0 to n - 1, and an array of n - 1 edges where edges[i] = [ai, bi] indicates that there is an undirected edge between the two nodes ai and bi in the tree, you can choose any node of the tree as the root. When you select a node x as the root, the resulting tree has height h, the number of edges on the longest downward path from the root to a leaf. Among all possible rooted trees, those with minimum height are called minimum height trees (MHTs). Return a list of all MHTs’ root labels.
Example 1:
Input: n = 4, edges = [[1,0],[1,2],[1,3]]
Output: [1]
Example 2:
Input: n = 6, edges = [[3,0],[3,1],[3,2],[3,4],[5,4]]
Output: [3,4]
Constraints:
- 1 <= n <= 2 * 10^4
- edges.length == n - 1
- 0 <= ai, bi < n
- All the pairs (ai, bi) are distinct.
Prerequisites
Before attempting this problem, you should be comfortable with:
- Tree properties — the center(s) of a tree minimize the height
- Topological peeling (leaf removal) — iteratively removing leaf nodes finds the center
- BFS — used to process nodes level by level during peeling
1. Brute Force (BFS from every node)
Intuition
For each node, run BFS to compute the height of the tree rooted at that node. Track the minimum height and collect all nodes that achieve it. This is correct but O(N²): each BFS costs O(N), and with n up to 2 * 10^4 the total is hundreds of millions of steps. That is fine for small trees but too slow for the constraints.
Algorithm
- For each node from 0 to n-1, run BFS to get the tree height rooted there.
- Collect all nodes with minimum height.
- Return them.
Solution
from collections import deque
def findMinHeightTrees(n: int, edges: list[list[int]]) -> list[int]:
    if n == 1:
        return [0]
    # Undirected adjacency list
    graph = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)
    def bfs_height(root):
        # Level-order BFS; counts levels, so the height in edges is levels - 1
        visited = {root}
        queue = deque([root])
        height = 0
        while queue:
            for _ in range(len(queue)):
                node = queue.popleft()
                for neighbor in graph[node]:
                    if neighbor not in visited:
                        visited.add(neighbor)
                        queue.append(neighbor)
            height += 1
        return height - 1
    # Try every root and keep all roots that tie for the minimum height
    min_h = float('inf')
    result = []
    for node in range(n):
        h = bfs_height(node)
        if h < min_h:
            min_h = h
            result = [node]
        elif h == min_h:
            result.append(node)
    return result
print(findMinHeightTrees(4, [[1,0],[1,2],[1,3]])) # [1]
print(findMinHeightTrees(6, [[3,0],[3,1],[3,2],[3,4],[5,4]])) # [3, 4]
print(findMinHeightTrees(1, [])) # [0]
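To see exactly what the brute force compares, the sketch below computes the height for every possible root of Example 2. It uses the same BFS idea but records distances instead of counting levels; `all_root_heights` is a hypothetical helper, not part of the original solution.

```python
from collections import deque

def all_root_heights(n: int, edges: list[list[int]]) -> dict[int, int]:
    # Hypothetical helper: height (in edges) of the tree rooted at each node.
    graph = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)

    heights = {}
    for root in range(n):
        # BFS records each node's distance from root; the largest
        # distance is the height of the tree rooted there.
        dist = {root: 0}
        queue = deque([root])
        while queue:
            node = queue.popleft()
            for neighbor in graph[node]:
                if neighbor not in dist:
                    dist[neighbor] = dist[node] + 1
                    queue.append(neighbor)
        heights[root] = max(dist.values())
    return heights

print(all_root_heights(6, [[3,0],[3,1],[3,2],[3,4],[5,4]]))
# {0: 3, 1: 3, 2: 3, 3: 2, 4: 2, 5: 3}
```

Only roots 3 and 4 reach the minimum height of 2, which matches the expected output [3,4].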
Complexity
- Time: O(N²), one O(N) BFS from each of the N nodes
- Space: O(N)
2. Topological Peeling (Leaf Trimming)
Intuition
The key insight: the root(s) of the minimum height tree are always the center node(s) of the tree — there are at most 2 such centers. We find them by repeatedly “peeling” leaf nodes (degree 1) from the outside inward. This is like topological sort for trees. We stop when 1 or 2 nodes remain — those are the answer.
Think of it like finding the middle of a path by walking in from both ends at the same speed: where the two ends meet is the middle. Peeling all leaves at once does this for every branch of the tree simultaneously.
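The claim that the answer is the tree's center can be cross-checked with a different technique than peeling: find a longest path (the diameter) with two BFS passes, then take its middle node or nodes. This is a swapped-in alternative for illustration, not the method this write-up uses; the function names `farthest` and `centers_via_diameter` are hypothetical.

```python
from collections import deque

def farthest(start: int, graph: list[list[int]]) -> tuple[int, list[int]]:
    # BFS from start; returns the last node dequeued (a farthest node,
    # since BFS dequeues in nondecreasing distance) and the BFS parents.
    parent = [-1] * len(graph)
    seen = [False] * len(graph)
    seen[start] = True
    queue = deque([start])
    last = start
    while queue:
        node = queue.popleft()
        last = node
        for nb in graph[node]:
            if not seen[nb]:
                seen[nb] = True
                parent[nb] = node
                queue.append(nb)
    return last, parent

def centers_via_diameter(n: int, edges: list[list[int]]) -> list[int]:
    graph = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)
    # In a tree, a node farthest from any start is one end of a diameter.
    u, _ = farthest(0, graph)
    # The node farthest from u is the other end.
    v, parent = farthest(u, graph)
    # Walk parents back from v to u to recover the diameter path.
    path = []
    node = v
    while node != -1:
        path.append(node)
        node = parent[node]
    k = len(path)
    # Odd node count: one middle node; even node count: two.
    if k % 2 == 1:
        return [path[k // 2]]
    return sorted([path[k // 2 - 1], path[k // 2]])

print(centers_via_diameter(6, [[3,0],[3,1],[3,2],[3,4],[5,4]]))  # [3, 4]
```

Both approaches run in O(N); the diameter version makes the "middle of the longest path" intuition explicit, while peeling avoids reconstructing the path.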
Algorithm
- Build the adjacency list and compute the degree of each node.
- Initialize a queue with all leaf nodes (degree == 1).
- Set remaining = n. While remaining > 2:
  - Process all current leaves and subtract their count from remaining.
  - For each leaf’s neighbor, decrement its degree. If the degree reaches 1, the neighbor is a new leaf, so enqueue it.
- Return whatever nodes remain (the last batch of leaves).
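The steps above track degrees explicitly, while the solution that follows removes entries from adjacency sets instead. A minimal sketch that follows the listed steps literally, with a plain degree array and adjacency lists, might look like this (the name `find_mht_degrees` is hypothetical):

```python
from collections import deque

def find_mht_degrees(n: int, edges: list[list[int]]) -> list[int]:
    if n == 1:
        return [0]
    graph = [[] for _ in range(n)]
    degree = [0] * n
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)
        degree[a] += 1
        degree[b] += 1

    # Every node with degree 1 is an initial leaf.
    leaves = deque(i for i in range(n) if degree[i] == 1)
    remaining = n
    while remaining > 2:
        # Peel the whole current layer of leaves.
        remaining -= len(leaves)
        for _ in range(len(leaves)):
            leaf = leaves.popleft()
            for nb in graph[leaf]:
                # The neighbor loses one edge; at degree 1 it becomes a
                # leaf of the next layer. Already-peeled neighbors drop
                # from 1 to 0 here, so they are never re-enqueued.
                degree[nb] -= 1
                if degree[nb] == 1:
                    leaves.append(nb)
    # The last layer holds the center node(s).
    return list(leaves)

print(find_mht_degrees(6, [[3,0],[3,1],[3,2],[3,4],[5,4]]))  # [3, 4]
```

Edges are never deleted here, which is why no set is needed; the degree counter alone decides when a node becomes a leaf.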
Solution
from collections import deque
def findMinHeightTrees(n: int, edges: list[list[int]]) -> list[int]:
    if n == 1:
        return [0]
    if n == 2:
        return [0, 1]
    # Adjacency sets so a peeled leaf can be removed from its neighbor
    graph = [set() for _ in range(n)]
    for a, b in edges:
        graph[a].add(b)
        graph[b].add(a)
    # Start with all leaf nodes
    leaves = deque(node for node in range(n) if len(graph[node]) == 1)
    remaining = n
    while remaining > 2:
        # Peel the entire current layer of leaves in one round
        remaining -= len(leaves)
        new_leaves = deque()
        while leaves:
            leaf = leaves.popleft()
            # There's exactly one neighbor for a leaf
            neighbor = next(iter(graph[leaf]))
            graph[neighbor].discard(leaf)
            if len(graph[neighbor]) == 1:
                new_leaves.append(neighbor)
        leaves = new_leaves
    return list(leaves)
print(findMinHeightTrees(4, [[1,0],[1,2],[1,3]])) # [1]
print(findMinHeightTrees(6, [[3,0],[3,1],[3,2],[3,4],[5,4]])) # [3, 4]
print(findMinHeightTrees(1, [])) # [0]
print(findMinHeightTrees(2, [[0,1]])) # [0, 1]
print(findMinHeightTrees(7, [[0,1],[1,2],[1,3],[2,4],[3,5],[4,6]])) # [1, 2]
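To watch the peeling happen, the sketch below records every layer of leaves removed for the 7-node test case. It is an instrumented copy of the set-based solution; `peel_layers` is a hypothetical name, and it uses plain lists per layer instead of a deque since each layer is consumed whole.

```python
def peel_layers(n: int, edges: list[list[int]]) -> list[list[int]]:
    # Returns each layer of leaves in peeling order; the last is the answer.
    if n == 1:
        return [[0]]
    graph = [set() for _ in range(n)]
    for a, b in edges:
        graph[a].add(b)
        graph[b].add(a)
    leaves = [node for node in range(n) if len(graph[node]) == 1]
    layers = []
    remaining = n
    while remaining > 2:
        layers.append(leaves)
        remaining -= len(leaves)
        new_leaves = []
        for leaf in leaves:
            neighbor = next(iter(graph[leaf]))
            graph[neighbor].discard(leaf)
            if len(graph[neighbor]) == 1:
                new_leaves.append(neighbor)
        leaves = new_leaves
    layers.append(leaves)
    return layers

for i, layer in enumerate(peel_layers(7, [[0,1],[1,2],[1,3],[2,4],[3,5],[4,6]])):
    print(i, layer)
# 0 [0, 5, 6]
# 1 [3, 4]
# 2 [1, 2]
```

The outer leaves 0, 5, 6 go first, then 3 and 4, leaving the two centers 1 and 2 of the longest path 6-4-2-1-3-5.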
Complexity
- Time: O(N), each node and edge is processed at most once
- Space: O(N)
Common Pitfalls
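Edge case n=1. A single node with no edges has height 0 and is trivially the answer. Handle it before building the adjacency list: node 0 has degree 0, so the initial leaf queue would be empty and the peeling solution would return [] instead of [0].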
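Stopping condition: remaining > 2. A tree has at most 2 center nodes: the middle of a longest path (the diameter) is a single node when that path has an odd number of nodes and two adjacent nodes when it has an even number. Stop when 1 or 2 nodes remain rather than peeling until the queue is empty. If two nodes are left, both are leaves, so one more round would peel away the answer itself. With valid tree input, the leaf queue never runs dry while more than 2 nodes remain.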
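A quick way to see why the loop must stop at two nodes: in the sketch below, a hypothetical `drain` flag keeps peeling until the leaf queue is empty instead. On the path 0-1-2-3 the two centers become leaves together and get peeled away, leaving nothing. The sketch tracks degrees in an array rather than using the set-based `next(iter(graph[leaf]))`, which would raise StopIteration once a leaf's neighbor set is empty.

```python
from collections import deque

def peel(n: int, edges: list[list[int]], drain: bool = False) -> list[int]:
    # drain=False: stop once at most 2 nodes remain (the correct rule).
    # drain=True: hypothetical buggy variant that peels until the queue empties.
    graph = [[] for _ in range(n)]
    degree = [0] * n
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)
        degree[a] += 1
        degree[b] += 1
    leaves = deque(i for i in range(n) if degree[i] == 1)
    remaining = n
    while leaves and (drain or remaining > 2):
        remaining -= len(leaves)
        for _ in range(len(leaves)):
            leaf = leaves.popleft()
            for nb in graph[leaf]:
                degree[nb] -= 1
                if degree[nb] == 1:
                    leaves.append(nb)
    return list(leaves)

path = [[0, 1], [1, 2], [2, 3]]
print(peel(4, path))              # [1, 2]: stops with the two centers
print(peel(4, path, drain=True))  # []: the centers were peeled too
```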
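Using a set for adjacency. Storing each node's neighbors in a set instead of a list lets graph[neighbor].discard(leaf) run in O(1) average time; removing from a list is O(degree). Alternatively, keep a separate degree array and never remove edges at all, as the Algorithm steps describe.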