Kth Smallest Element in a BST

Difficulty: Medium Source: NeetCode

Problem

Given the root of a binary search tree, and an integer k, return the kth smallest value (1-indexed) of all the values of the nodes in the tree.

Example 1:
Input: root = [3,1,4,null,2], k = 1
Output: 1

Example 2:
Input: root = [5,3,6,2,4,null,null,1], k = 3
Output: 3

Constraints:

  • The number of nodes in the tree is n
  • 1 <= k <= n <= 10^4
  • 0 <= Node.val <= 10^4

Prerequisites

Before attempting this problem, you should be comfortable with:

  • Inorder Traversal: an inorder traversal of a BST visits the values in ascending order
  • BST Property: every value in a node's left subtree is smaller than the node's value, and every value in its right subtree is larger
  • Early Exit: stopping the traversal as soon as the kth element has been found
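The first prerequisite is easy to check directly. The sketch below builds Example 2's tree by plain BST insertion and confirms that an inorder traversal returns the values in sorted order. `bst_insert` is a hypothetical helper written only for this illustration, and it assumes that any duplicate values go to the right subtree.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def bst_insert(root, val):
    # Hypothetical helper: iterative BST insert (duplicates assumed to go right)
    new = TreeNode(val)
    if root is None:
        return new
    node = root
    while True:
        if val < node.val:
            if node.left is None:
                node.left = new
                return root
            node = node.left
        else:
            if node.right is None:
                node.right = new
                return root
            node = node.right

def inorder(node, out):
    # Left subtree, then the node, then the right subtree
    if node:
        inorder(node.left, out)
        out.append(node.val)
        inorder(node.right, out)
    return out

# Inserting in this order reproduces Example 2: [5,3,6,2,4,null,null,1]
root = None
for v in [5, 3, 6, 2, 4, 1]:
    root = bst_insert(root, v)

print(inorder(root, []))  # [1, 2, 3, 4, 5, 6]
```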

1. Inorder Traversal (Collect All)

Intuition

Inorder traversal of a BST visits nodes in ascending order. So the kth element in an inorder traversal is the kth smallest element. The simplest version collects all values in order, then returns index k-1.

Algorithm

  1. Do inorder traversal, collect all values into a list
  2. Return values[k-1]

Solution

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def kth_smallest_naive(root, k):
    values = []

    def inorder(node):
        if not node:
            return
        inorder(node.left)
        values.append(node.val)
        inorder(node.right)

    inorder(root)
    return values[k - 1]

# --- helpers ---
def build_tree(values):
    if not values:
        return None
    root = TreeNode(values[0])
    queue = [root]
    i = 1
    while queue and i < len(values):
        node = queue.pop(0)
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root

# --- tests ---
root1 = build_tree([3, 1, 4, None, 2])
print(kth_smallest_naive(root1, 1))  # 1

root2 = build_tree([5, 3, 6, 2, 4, None, None, 1])
print(kth_smallest_naive(root2, 3))  # 3

root3 = build_tree([3, 1, 4, None, 2])
print(kth_smallest_naive(root3, 3))  # 3

Complexity

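  • Time: O(n), since every node is visited even when k is small
  • Space: O(n) for the list of values, plus O(H) for the recursion stack, where H is the height of the tree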
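The recursive version has a practical catch in Python. The constraints allow n = 10^4, and a degenerate, chain-shaped BST has height close to n. CPython's default recursion limit is 1000. The sketch below builds such a chain with a hypothetical `build_left_chain` helper and runs copies of both solutions on it. Under the default limit, the recursive traversal is expected to raise `RecursionError`, while the explicit-stack traversal still returns the answer.

```python
import sys

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def kth_smallest_naive(root, k):
    # Recursive collect-all version (same logic as the solution above)
    values = []

    def inorder(node):
        if not node:
            return
        inorder(node.left)
        values.append(node.val)
        inorder(node.right)

    inorder(root)
    return values[k - 1]

def kth_smallest_iterative(root, k):
    # Explicit-stack inorder with early exit: no Python recursion involved
    stack, curr = [], root
    while curr or stack:
        while curr:
            stack.append(curr)
            curr = curr.left
        curr = stack.pop()
        k -= 1
        if k == 0:
            return curr.val
        curr = curr.right
    return -1

def build_left_chain(n):
    # Hypothetical helper: a valid but degenerate BST n -> n-1 -> ... -> 1,
    # where each node is the left child of the previous one
    root = TreeNode(n)
    node = root
    for v in range(n - 1, 0, -1):
        node.left = TreeNode(v)
        node = node.left
    return root

chain = build_left_chain(10_000)
print("recursion limit:", sys.getrecursionlimit())
try:
    print("recursive:", kth_smallest_naive(chain, 1))
except RecursionError:
    # Expected under CPython's default limit of 1000
    print("recursive: RecursionError")
print("iterative:", kth_smallest_iterative(chain, 1))  # iterative: 1
```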

2. Iterative Inorder with Early Exit (Optimal)

Intuition

The naive approach traverses the whole tree even after finding the answer. We can stop early using an iterative inorder traversal. Each time we pop a node, that’s the next smallest value. After k pops, we have our answer.
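The same early-exit idea can also be written with a Python generator. A lazy inorder iterator yields values one at a time, and `itertools.islice` stops consuming it after the kth value, so the rest of the tree is never visited. This is an alternative formulation sketched for illustration, not the reference solution. `inorder_iter` and `kth_smallest_lazy` are illustrative names.

```python
from itertools import islice

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def inorder_iter(root):
    # Lazily yield values in ascending order using an explicit stack
    stack, curr = [], root
    while curr or stack:
        while curr:
            stack.append(curr)
            curr = curr.left
        curr = stack.pop()
        yield curr.val
        curr = curr.right

def kth_smallest_lazy(root, k):
    # Skip the first k - 1 values and take the next one; the generator is
    # never resumed after that, so the remaining nodes are not visited
    return next(islice(inorder_iter(root), k - 1, None))

# Example 1: root = [3,1,4,null,2]
root = TreeNode(3, TreeNode(1, None, TreeNode(2)), TreeNode(4))
print(kth_smallest_lazy(root, 1))  # 1
print(kth_smallest_lazy(root, 3))  # 3
```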

Algorithm

  1. Start with an empty stack and curr = root
  2. Drill left as far as possible, pushing nodes onto the stack
  3. Pop a node — this is the next smallest
  4. Decrement k; if k == 0, return this node’s value
  5. Move to the right child and repeat
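The sketch below makes these steps concrete by running the same loop on Example 2. For each pop, it records the popped value and the values still left on the stack. `trace_pops` is a hypothetical instrumented copy of the loop, written only for illustration.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def trace_pops(root, k):
    # Same loop as the algorithm, but records (popped value, stack contents)
    stack, curr, trace = [], root, []
    while curr or stack:
        while curr:                      # drill left, pushing nodes
            stack.append(curr)
            curr = curr.left
        curr = stack.pop()               # next smallest value
        trace.append((curr.val, [n.val for n in stack]))
        k -= 1
        if k == 0:
            break
        curr = curr.right                # continue in the right subtree
    return trace

# Example 2: root = [5,3,6,2,4,null,null,1]
root = TreeNode(5,
                TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)),
                TreeNode(6))
for val, on_stack in trace_pops(root, 3):
    print(val, on_stack)
# 1 [5, 3, 2]
# 2 [5, 3]
# 3 [5]
```

In this trace, nodes 4 and 6 are never pushed, because the loop stops at the third pop.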

Solution

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def kth_smallest(root, k):
    stack = []
    curr = root

    while curr or stack:
        # Go as far left as possible
        while curr:
            stack.append(curr)
            curr = curr.left

        # Pop = visit the next smallest node
        curr = stack.pop()
        k -= 1
        if k == 0:
            return curr.val

        # Move to right child for next iteration
        curr = curr.right

    return -1  # Should never reach here given k <= n

# --- helpers ---
def build_tree(values):
    if not values:
        return None
    root = TreeNode(values[0])
    queue = [root]
    i = 1
    while queue and i < len(values):
        node = queue.pop(0)
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root

# --- tests ---
root1 = build_tree([3, 1, 4, None, 2])
print(kth_smallest(root1, 1))  # 1
print(kth_smallest(root1, 2))  # 2
print(kth_smallest(root1, 3))  # 3

root2 = build_tree([5, 3, 6, 2, 4, None, None, 1])
print(kth_smallest(root2, 3))  # 3

# Edge case: single node
root3 = build_tree([1])
print(kth_smallest(root3, 1))  # 1

Complexity

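  • Time: O(H + k), where H is the tree height. Every node that gets pushed is either one of the k popped nodes or still on the stack when we return, and the stack never holds more than one root-to-node path.
  • Space: O(H), since the stack only holds nodes along a single root-to-node path. H is O(log n) for a balanced tree but can approach n for a degenerate, chain-shaped one.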
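One rough way to see the O(H + k) bound is to count how many nodes get pushed. The sketch below builds a balanced BST over 1..1023 (10 levels) with a hypothetical `build_balanced` helper. It then counts pushes in an instrumented copy of the loop, `kth_smallest_counting`, which is also a hypothetical name. The naive version would visit all 1023 nodes regardless of k.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def build_balanced(lo, hi):
    # Hypothetical helper: balanced BST over lo..hi, middle value as root
    if lo > hi:
        return None
    mid = (lo + hi) // 2
    return TreeNode(mid, build_balanced(lo, mid - 1), build_balanced(mid + 1, hi))

def kth_smallest_counting(root, k):
    # Early-exit inorder that also returns how many nodes were pushed
    stack, curr, pushes = [], root, 0
    while curr or stack:
        while curr:
            stack.append(curr)
            pushes += 1
            curr = curr.left
        curr = stack.pop()
        k -= 1
        if k == 0:
            return curr.val, pushes
        curr = curr.right
    return -1, pushes

root = build_balanced(1, 1023)  # 1023 nodes, 10 levels
print(kth_smallest_counting(root, 1))  # (1, 10): just the left spine
print(kth_smallest_counting(root, 5))  # (5, 13)
```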

Common Pitfalls

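Using k as a global counter carelessly. In a recursive inorder, decrementing the parameter k only changes the current call's local copy. The count therefore has to live in shared state, such as a nonlocal or global variable. Every call must also check whether the answer has already been found, so that the traversal actually stops. The iterative approach avoids both issues because it can return directly from the loop.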
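For comparison, here is one way a recursive early-exit version can manage the counter, using Python's `nonlocal` for the shared state. This is a sketch, not the reference solution. `remaining` and `answer` live in the enclosing function, and each call returns immediately once `answer` has been set. Like any recursive traversal, it is still bounded by Python's recursion limit on very deep trees.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def kth_smallest_recursive(root, k):
    remaining = k   # shared countdown; a plain parameter would only change locally
    answer = None   # set once the kth node is visited (0 is a valid value)

    def inorder(node):
        nonlocal remaining, answer
        if node is None or answer is not None:
            return
        inorder(node.left)
        if answer is not None:  # found in the left subtree: stop unwinding work
            return
        remaining -= 1
        if remaining == 0:
            answer = node.val
            return
        inorder(node.right)

    inorder(root)
    return answer

root1 = TreeNode(3, TreeNode(1, None, TreeNode(2)), TreeNode(4))  # Example 1
root2 = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))  # Example 2
print(kth_smallest_recursive(root1, 1))  # 1
print(kth_smallest_recursive(root2, 3))  # 3
```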

Forgetting that BST inorder is ascending. The kth smallest is at index k-1 in a 0-indexed list, or the kth node popped in an iterative inorder. Don’t confuse “kth smallest” with “kth node in some other order.”
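A small check on Example 1 shows both mistakes. The kth value in preorder is not the kth smallest, and indexing the inorder list with k instead of k - 1 is off by one. The `preorder` and `inorder` helpers below are written just for this comparison.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def preorder(node, out):
    # Node first, then left subtree, then right subtree
    if node:
        out.append(node.val)
        preorder(node.left, out)
        preorder(node.right, out)
    return out

def inorder(node, out):
    # Left subtree, then node, then right subtree (ascending for a BST)
    if node:
        inorder(node.left, out)
        out.append(node.val)
        inorder(node.right, out)
    return out

root = TreeNode(3, TreeNode(1, None, TreeNode(2)), TreeNode(4))  # Example 1
k = 1
print(inorder(root, [])[k - 1])   # 1: the kth smallest
print(preorder(root, [])[k - 1])  # 3: kth node in preorder, not the answer
print(inorder(root, [])[k])       # 2: off by one, this is the (k+1)th smallest
```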