BST Sets and Maps

Sorted sets and maps in many standard libraries are trees under the hood.

When you call my_set.add(x) or my_dict[key] = value in Python, the built-in types use hash tables internally. Hash tables are fast, but they cannot efficiently tell you the smallest key, the largest key, or all keys in a given range: they keep no sorted order (a dict remembers insertion order only), so each of those questions needs a full scan or an explicit sort. A BST-backed set or map costs O(log n) per operation when the tree is balanced, instead of O(1) on average, but it buys you something invaluable: the keys stay sorted at all times.

A BST as a sorted set

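A set stores unique values and supports three core operations: insert, contains, and remove. A BST handles all three in O(h) time, where h is the height of the tree: that is O(log n) when the tree is reasonably balanced and O(n) in the worst case, because the plain BST used here does no rebalancing. It also gives you sorted iteration for free.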

flowchart TD
    Root["8 (root)"]
    Root --> L3["3"]
    Root --> R10["10"]
    L3 --> LL1["1"]
    L3 --> LR6["6"]
    LR6 --> LRL4["4"]
    LR6 --> LRR7["7"]
    R10 --> RR14["14"]

In-order traversal reads this tree as: 1, 3, 4, 6, 7, 8, 10, 14 — a sorted set.

class TreeNode:
    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None


class BSTSet:
    """A sorted set backed by a Binary Search Tree."""

    def __init__(self):
        self._root = None

    def add(self, value):
        self._root = self._insert(self._root, value)

    def contains(self, value):
        return self._search(self._root, value)

    def remove(self, value):
        self._root = self._delete(self._root, value)

    def to_sorted_list(self):
        """In-order traversal — always returns values in ascending order."""
        result = []
        self._inorder(self._root, result)
        return result

    def minimum(self):
        node = self._root
        while node and node.left:
            node = node.left
        return node.value if node else None

    def maximum(self):
        node = self._root
        while node and node.right:
            node = node.right
        return node.value if node else None

    def range_query(self, low, high):
        """Return all values where low <= value <= high, in sorted order."""
        result = []
        self._range(self._root, low, high, result)
        return result

    # ---- internal helpers ----

    def _insert(self, node, value):
        if node is None:
            return TreeNode(value)
        if value < node.value:
            node.left = self._insert(node.left, value)
        elif value > node.value:
            node.right = self._insert(node.right, value)
        return node

    def _search(self, node, value):
        if node is None:
            return False
        if value == node.value:
            return True
        if value < node.value:
            return self._search(node.left, value)
        return self._search(node.right, value)

    def _find_min(self, node):
        while node.left:
            node = node.left
        return node

    def _delete(self, node, value):
        if node is None:
            return None
        if value < node.value:
            node.left = self._delete(node.left, value)
        elif value > node.value:
            node.right = self._delete(node.right, value)
        else:
            if node.left is None:
                return node.right
            if node.right is None:
                return node.left
            successor = self._find_min(node.right)
            node.value = successor.value
            node.right = self._delete(node.right, successor.value)
        return node

    def _inorder(self, node, result):
        if node:
            self._inorder(node.left, result)
            result.append(node.value)
            self._inorder(node.right, result)

    def _range(self, node, low, high, result):
        if node is None:
            return
        if node.value > low:      # there might be values >= low in left subtree
            self._range(node.left, low, high, result)
        if low <= node.value <= high:
            result.append(node.value)
        if node.value < high:     # there might be values <= high in right subtree
            self._range(node.right, low, high, result)


s = BSTSet()
for val in [8, 3, 10, 1, 6, 4, 7, 14]:
    s.add(val)

print("Sorted contents:    ", s.to_sorted_list())
print("Contains 6:         ", s.contains(6))
print("Contains 5:         ", s.contains(5))
print("Minimum:            ", s.minimum())
print("Maximum:            ", s.maximum())
print("Range [4, 8]:       ", s.range_query(4, 8))

s.remove(6)
print("After removing 6:   ", s.to_sorted_list())
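
Because the tree is ordered, it can also answer "nearest value" questions that a hash set cannot answer without a scan. The sketch below is not part of the BSTSet class above: ceiling and floor are hypothetical helper names (borrowed from the naming in Java's TreeSet), written as standalone functions over a minimal Node class so the block runs on its own.

```python
class Node:
    """Minimal BST node, redefined here so the sketch is self-contained."""
    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None


def insert(node, value):
    """Insert value into the subtree rooted at node; duplicates are ignored."""
    if node is None:
        return Node(value)
    if value < node.value:
        node.left = insert(node.left, value)
    elif value > node.value:
        node.right = insert(node.right, value)
    return node


def ceiling(node, target):
    """Smallest stored value >= target, or None if there is none (hypothetical helper)."""
    best = None
    while node:
        if node.value == target:
            return node.value
        if node.value > target:
            best = node.value   # candidate; a smaller one may exist to the left
            node = node.left
        else:
            node = node.right   # too small; only the right subtree can help
    return best


def floor(node, target):
    """Largest stored value <= target, or None if there is none (hypothetical helper)."""
    best = None
    while node:
        if node.value == target:
            return node.value
        if node.value < target:
            best = node.value   # candidate; a larger one may exist to the right
            node = node.right
        else:
            node = node.left    # too large; only the left subtree can help
    return best


root = None
for val in [8, 3, 10, 1, 6, 4, 7, 14]:
    root = insert(root, val)

print("ceiling(5): ", ceiling(root, 5))    # 6
print("floor(5):   ", floor(root, 5))      # 4
print("ceiling(15):", ceiling(root, 15))   # None
print("floor(0):   ", floor(root, 0))      # None
```

Each call follows a single root-to-leaf path, so the cost is proportional to the height of the tree, the same as contains.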

A BST as an ordered map (dictionary)

A map stores key-value pairs and supports put(key, value) and get(key). The BST is ordered by key, and each node carries an associated value as cargo.

class MapNode:
    def __init__(self, key, value):
        self.key = key
        self.value = value
        self.left = None
        self.right = None


class BSTMap:
    """
    An ordered map (dictionary) backed by a BST.
    Keys are kept sorted, so you get sorted iteration and range queries
    that a plain Python dict cannot provide.
    """

    def __init__(self):
        self._root = None

    def put(self, key, value):
        """Insert or update a key-value pair."""
        self._root = self._insert(self._root, key, value)

    def get(self, key, default=None):
        """Return the value for key, or default if not found."""
        node = self._search(self._root, key)
        return node.value if node else default

    def keys_sorted(self):
        """Return all keys in ascending order."""
        result = []
        self._inorder(self._root, result)
        return result

    def items_in_range(self, low_key, high_key):
        """Return (key, value) pairs where low_key <= key <= high_key, sorted."""
        result = []
        self._range(self._root, low_key, high_key, result)
        return result

    # ---- internal helpers ----

    def _insert(self, node, key, value):
        if node is None:
            return MapNode(key, value)
        if key < node.key:
            node.left = self._insert(node.left, key, value)
        elif key > node.key:
            node.right = self._insert(node.right, key, value)
        else:
            node.value = value   # update existing key
        return node

    def _search(self, node, key):
        if node is None:
            return None
        if key == node.key:
            return node
        if key < node.key:
            return self._search(node.left, key)
        return self._search(node.right, key)

    def _inorder(self, node, result):
        if node:
            self._inorder(node.left, result)
            result.append(node.key)
            self._inorder(node.right, result)

    def _range(self, node, low, high, result):
        if node is None:
            return
        if node.key > low:
            self._range(node.left, low, high, result)
        if low <= node.key <= high:
            result.append((node.key, node.value))
        if node.key < high:
            self._range(node.right, low, high, result)


# Store word frequencies
freq_map = BSTMap()
words = ["banana", "apple", "cherry", "apple", "banana", "banana", "date"]
for word in words:
    count = freq_map.get(word, 0)
    freq_map.put(word, count + 1)

print("Keys in alphabetical order:", freq_map.keys_sorted())
print("Frequency of 'banana':     ", freq_map.get("banana"))
print("Frequency of 'grape':      ", freq_map.get("grape", 0))
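# The upper bound is "d", not "c": as strings, "cherry" > "c", so a bound of "c" would exclude it
print("Words starting a-c:        ", freq_map.items_in_range("a", "d"))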
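
The same key ordering also makes the smallest and largest key cheap to find, and lets you walk the entries lazily instead of building a list first. The following is a minimal sketch under stated assumptions: min_key, max_key, and iter_items are hypothetical names that do not exist on the BSTMap class above, and the Node class and put function are redefined so the block runs on its own.

```python
class Node:
    """Minimal key-value BST node, redefined so the sketch is self-contained."""
    def __init__(self, key, value):
        self.key = key
        self.value = value
        self.left = None
        self.right = None


def put(node, key, value):
    """Insert or update key in the subtree rooted at node."""
    if node is None:
        return Node(key, value)
    if key < node.key:
        node.left = put(node.left, key, value)
    elif key > node.key:
        node.right = put(node.right, key, value)
    else:
        node.value = value
    return node


def min_key(node):
    """Leftmost key in the tree, or None if the tree is empty (hypothetical helper)."""
    if node is None:
        return None
    while node.left:
        node = node.left
    return node.key


def max_key(node):
    """Rightmost key in the tree, or None if the tree is empty (hypothetical helper)."""
    if node is None:
        return None
    while node.right:
        node = node.right
    return node.key


def iter_items(node):
    """Yield (key, value) pairs in ascending key order, one at a time."""
    if node:
        yield from iter_items(node.left)
        yield (node.key, node.value)
        yield from iter_items(node.right)


root = None
for word, count in [("banana", 3), ("apple", 2), ("cherry", 1), ("date", 1)]:
    root = put(root, word, count)

print("Smallest key:", min_key(root))   # apple
print("Largest key: ", max_key(root))   # date
for key, value in iter_items(root):
    print(key, value)
```

A generator suits large maps because a caller that only needs the first few entries can stop early without visiting the rest of the tree.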

Hash table vs BST-backed map — a direct comparison

# Python's built-in dict is hash-based: fast, but not sorted by key
plain_dict = {"banana": 3, "apple": 2, "cherry": 1, "date": 1}

# Since Python 3.7 a dict preserves insertion order, which is not sorted order
print("dict keys (insertion order):", list(plain_dict.keys()))

# To get sorted keys from a plain dict you must sort explicitly: O(n log n)
print("dict sorted keys:           ", sorted(plain_dict.keys()))

# A BST map keeps keys sorted automatically: each put() maintains order.
# (A compact copy of the BSTMap from the previous cell is repeated below
# so that this cell runs on its own.)

class MapNode:
    def __init__(self, key, value):
        self.key = key
        self.value = value
        self.left = None
        self.right = None

class BSTMap:
    def __init__(self):
        self._root = None
    def put(self, key, value):
        self._root = self._ins(self._root, key, value)
    def _ins(self, node, key, value):
        if node is None:
            return MapNode(key, value)
        if key < node.key:
            node.left = self._ins(node.left, key, value)
        elif key > node.key:
            node.right = self._ins(node.right, key, value)
        else:
            node.value = value
        return node
    def keys_sorted(self):
        r = []
        self._io(self._root, r)
        return r
    def _io(self, node, r):
        if node:
            self._io(node.left, r)
            r.append(node.key)
            self._io(node.right, r)

bst_map = BSTMap()
for k, v in plain_dict.items():
    bst_map.put(k, v)

print("BSTMap keys (always sorted):", bst_map.keys_sorted())
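
| Feature | Hash dict (dict) | BST map (BSTMap) |
|---|---|---|
| Average lookup | O(1) | O(log n) |
| Sorted iteration | No built-in support; sorted() costs O(n log n) | O(n), free via in-order traversal |
| Range query | No built-in support; O(n) scan | O(log n + k) when balanced |
| Min / max key | No built-in support; O(n) scan | O(log n) when balanced |
| Worst-case lookup | O(n), hash collisions | O(n), unbalanced tree |

Here n is the number of keys stored and k is the number of keys a range query returns.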
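
The worst case for the BST is easy to reproduce: insert keys that are already sorted and every node ends up with only a right child. The sketch below measures tree height for sorted versus shuffled insertion order. The helper names (insert, height, build) are illustrative and not part of the classes above, and the insert is written iteratively so that a degenerate tree does not hit Python's recursion limit.

```python
import random


class Node:
    """Minimal BST node, redefined so the sketch is self-contained."""
    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None


def insert(root, value):
    """Iterative BST insert; returns the (possibly new) root. Duplicates are ignored."""
    if root is None:
        return Node(value)
    node = root
    while True:
        if value < node.value:
            if node.left is None:
                node.left = Node(value)
                break
            node = node.left
        elif value > node.value:
            if node.right is None:
                node.right = Node(value)
                break
            node = node.right
        else:
            break
    return root


def height(root):
    """Number of nodes on the longest root-to-leaf path (0 for an empty tree)."""
    best = 0
    stack = [(root, 1)] if root else []
    while stack:
        node, depth = stack.pop()
        best = max(best, depth)
        for child in (node.left, node.right):
            if child:
                stack.append((child, depth + 1))
    return best


def build(values):
    """Build a BST by inserting values in the given order."""
    root = None
    for v in values:
        root = insert(root, v)
    return root


n = 1000
sorted_keys = list(range(n))
shuffled_keys = sorted_keys[:]
random.Random(0).shuffle(shuffled_keys)   # fixed seed so runs are repeatable

# Sorted input degenerates into a chain: height equals n
print("Height after sorted inserts:  ", height(build(sorted_keys)))    # 1000
# Shuffled input usually gives a height that is a small multiple of log2(n)
print("Height after shuffled inserts:", height(build(shuffled_keys)))
```

This is why production ordered maps use self-balancing trees: they keep the height logarithmic regardless of insertion order.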

Real-world implementations

Python's sortedcontainers (SortedList, SortedDict): a third-party library that provides a sorted list and sorted dictionary with the same kind of ordered operations a BST offers. Internally it keeps lists of sorted sublists rather than a tree. It is widely used when you need range queries or sorted iteration alongside fast lookups.

Java's TreeMap and TreeSet: backed by a Red-Black tree (a self-balancing BST), these are Java's standard ordered map and ordered set. TreeMap methods such as firstKey(), lastKey(), and subMap(from, to) are all powered by the BST structure.

Database ordered indexes: a statement such as CREATE INDEX ON orders(order_date) typically builds a B-Tree (a disk-friendly generalisation of a BST). Queries like WHERE order_date BETWEEN '2024-01-01' AND '2024-06-30' use the tree to jump directly to the start of the range instead of scanning the whole table.

C++ std::map and std::set: typically implemented as Red-Black trees as well, giving O(log n) insert, lookup, and delete with full ordered iteration.
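
For comparison with the tree-based structures above, Python's standard-library bisect module can keep a plain list sorted and answer the same ordered queries. This is a rough sketch of the sorted-list idea only, not how sortedcontainers or any of the libraries above are implemented, and SortedKeys is a hypothetical class name. Binary search finds positions in O(log n), but each insert may shift list elements, so add is O(n) in the worst case.

```python
import bisect


class SortedKeys:
    """A sorted set kept in a plain list (hypothetical sketch, not a library class)."""

    def __init__(self):
        self._items = []

    def add(self, value):
        # Find the insertion point by binary search, then insert if not already present
        i = bisect.bisect_left(self._items, value)
        if i == len(self._items) or self._items[i] != value:
            self._items.insert(i, value)

    def contains(self, value):
        i = bisect.bisect_left(self._items, value)
        return i < len(self._items) and self._items[i] == value

    def range_query(self, low, high):
        """Return all values where low <= value <= high, in sorted order."""
        lo = bisect.bisect_left(self._items, low)
        hi = bisect.bisect_right(self._items, high)
        return self._items[lo:hi]

    def minimum(self):
        return self._items[0] if self._items else None

    def maximum(self):
        return self._items[-1] if self._items else None


keys = SortedKeys()
for val in [8, 3, 10, 1, 6, 4, 7, 14, 6]:   # the second 6 is ignored
    keys.add(val)

print("Contains 6:  ", keys.contains(6))        # True
print("Contains 5:  ", keys.contains(5))        # False
print("Minimum:     ", keys.minimum())          # 1
print("Maximum:     ", keys.maximum())          # 14
print("Range [4, 8]:", keys.range_query(4, 8))  # [4, 6, 7, 8]
```

The trade-off is the reverse of a BST's: reads are simple and cache-friendly, while writes pay for moving elements, which is one reason tree-based structures remain the usual choice for large, frequently updated ordered data.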