
Construct Quad Tree

Difficulty: Medium Source: NeetCode

Problem

Given an n x n matrix grid containing only 0s and 1s, represent grid with a Quad-Tree and return the root of that Quad-Tree.

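A Quad-Tree is a tree data structure in which each internal node has exactly four children (topLeft, topRight, bottomLeft, bottomRight). Each node has two attributes: isLeaf, which is True when the node represents a uniform region and has no children, and val, which for a leaf is True if its region is all 1s and False if it is all 0s. For an internal node (isLeaf is False), val may be either True or False.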

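In the examples, the output lists the tree in level order, writing each node as [isLeaf, val] (1 for True, 0 for False) and null where a node has no child; trailing nulls are omitted.

Example 1: Input: grid = [[0,1],[1,0]] Output: [[0,1],[1,0],[1,1],[1,1],[1,0]]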

Example 2: Input: grid = [[1,1,1,1,0,0,0,0],[1,1,1,1,0,0,0,0],[1,1,1,1,1,1,1,1],[1,1,1,1,1,1,1,1],[1,1,1,1,0,0,0,0],[1,1,1,1,0,0,0,0],[1,1,1,1,0,0,0,0],[1,1,1,1,0,0,0,0]] Output: [[0,1],[1,1],[0,1],[1,1],[1,0],null,null,null,null,[1,0],[1,0],[1,1],[1,1]]
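To check the example outputs against an actual tree, here is a minimal serializer sketch. It assumes the LeetCode-style level-order format: a leaf still contributes four null placeholders for its absent children (which is why four nulls follow the first five entries in Example 2), and the trailing run of nulls is trimmed. The name serialize is a hypothetical helper, and construct is a compact copy of the recursive approach from the Solution section, which writes val=True for internal nodes.

```python
from collections import deque


class Node:
    def __init__(self, val, isLeaf, topLeft=None, topRight=None,
                 bottomLeft=None, bottomRight=None):
        self.val = val
        self.isLeaf = isLeaf
        self.topLeft = topLeft
        self.topRight = topRight
        self.bottomLeft = bottomLeft
        self.bottomRight = bottomRight


def construct(grid):
    """Compact copy of the recursive divide-and-conquer solution."""
    def build(row, col, size):
        first = grid[row][col]
        if all(grid[r][c] == first
               for r in range(row, row + size)
               for c in range(col, col + size)):
            return Node(bool(first), True)
        half = size // 2
        return Node(True, False,
                    build(row, col, half), build(row, col + half, half),
                    build(row + half, col, half), build(row + half, col + half, half))
    return build(0, 0, len(grid))


def serialize(root):
    """Hypothetical helper: level-order list of [isLeaf, val], None for no child."""
    out = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)
            continue
        out.append([int(node.isLeaf), int(node.val)])
        # A leaf enqueues four None placeholders; an internal node its children.
        queue.extend([node.topLeft, node.topRight, node.bottomLeft, node.bottomRight])
    # Drop the trailing run of None placeholders.
    while out and out[-1] is None:
        out.pop()
    return out


print(serialize(construct([[0, 1], [1, 0]])))
# [[0, 1], [1, 0], [1, 1], [1, 1], [1, 0]]
```

With val=True on internal nodes, both examples come out exactly as listed; an implementation that stores False on internal nodes would print [0,0] for those entries, which the problem also accepts.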

Constraints:

  • n == grid.length == grid[i].length
  • n == 2^x where 0 <= x <= 6
  • 1 <= n <= 64
  • grid[i][j] is either 0 or 1

Prerequisites

Before attempting this problem, you should be comfortable with:

  • Divide and Conquer — Splitting a problem into four equal quadrants
  • Recursion — Solving the same problem on each quadrant
  • Grid Subregions — Working with row/column offsets and sizes

1. Recursive Divide and Conquer

Intuition

Build the quad tree by recursively dividing the grid into four equal quadrants. At each step, check if the current region is uniform (all 0s or all 1s). If it is, create a leaf node. If not, split into four quadrants (top-left, top-right, bottom-left, bottom-right) and recurse on each.

Algorithm

  1. Check if all values in the current (row, col, size) region are the same
  2. If uniform → create a leaf node with that value
  3. If not uniform → split into 4 quadrants, recurse on each, create internal node
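To see steps 1 to 3 on Example 2, the sketch below runs the same recursion but, instead of building nodes, records each region that turns out to be uniform as (row, col, size, value). The name leaf_regions is a hypothetical helper used only for illustration.

```python
def leaf_regions(grid):
    """Return (row, col, size, value) for every region that becomes a leaf."""
    leaves = []

    def visit(row, col, size):
        first = grid[row][col]
        uniform = all(grid[r][c] == first
                      for r in range(row, row + size)
                      for c in range(col, col + size))
        if uniform:
            # Step 2: the whole region collapses into a single leaf.
            leaves.append((row, col, size, first))
            return
        # Step 3: recurse in the order top-left, top-right, bottom-left, bottom-right.
        half = size // 2
        visit(row, col, half)
        visit(row, col + half, half)
        visit(row + half, col, half)
        visit(row + half, col + half, half)

    visit(0, 0, len(grid))
    return leaves


example2 = [
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0],
]
print(leaf_regions(example2))
# [(0, 0, 4, 1), (0, 4, 2, 0), (0, 6, 2, 0), (2, 4, 2, 1), (2, 6, 2, 1), (4, 0, 4, 1), (4, 4, 4, 0)]
```

Only the top-right 4 x 4 quadrant is mixed, so it is the only one split further. The seven leaves cover 16 + 4 + 4 + 4 + 4 + 16 + 16 = 64 cells, the whole grid, and correspond to the seven leaf entries ([1,0] or [1,1]) in Example 2's output.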

Solution

class Node:
    def __init__(self, val, isLeaf, topLeft=None, topRight=None,
                 bottomLeft=None, bottomRight=None):
        self.val = val
        self.isLeaf = isLeaf
        self.topLeft = topLeft
        self.topRight = topRight
        self.bottomLeft = bottomLeft
        self.bottomRight = bottomRight

def construct(grid):
    def is_uniform(row, col, size):
        """Check if all cells in the subgrid have the same value."""
        val = grid[row][col]
        for r in range(row, row + size):
            for c in range(col, col + size):
                if grid[r][c] != val:
                    return False
        return True

    def build(row, col, size):
        if is_uniform(row, col, size):
            # Leaf node: entire region has the same value
            return Node(val=bool(grid[row][col]), isLeaf=True)

        half = size // 2
        # Split into 4 quadrants
        return Node(
            val=True,  # val for internal nodes can be anything
            isLeaf=False,
            topLeft=build(row, col, half),
            topRight=build(row, col + half, half),
            bottomLeft=build(row + half, col, half),
            bottomRight=build(row + half, col + half, half),
        )

    return build(0, 0, len(grid))

# --- tests ---
grid1 = [[0, 1], [1, 0]]
root1 = construct(grid1)
print(root1.isLeaf)       # False (not uniform)
print(root1.topLeft.val)  # False (0)
print(root1.topRight.val) # True (1)

grid2 = [[1, 1], [1, 1]]
root2 = construct(grid2)
print(root2.isLeaf)  # True (uniform)
print(root2.val)     # True (all 1s)

grid3 = [
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0],
]
root3 = construct(grid3)
print(root3.isLeaf)  # False

# Verify structure
def print_tree(node, indent=0):
    if not node:
        return
    prefix = "  " * indent
    if node.isLeaf:
        print(f"{prefix}Leaf({int(node.val)})")
    else:
        print(f"{prefix}Internal")
        print_tree(node.topLeft, indent + 1)
        print_tree(node.topRight, indent + 1)
        print_tree(node.bottomLeft, indent + 1)
        print_tree(node.bottomRight, indent + 1)

print_tree(root1)
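print_tree shows the shape of the tree, but a stricter check is a round trip: expand the tree back into a grid and compare it with the input. The sketch below uses a hypothetical helper to_grid and repeats compact copies of Node and construct so it runs on its own; it exercises every allowed size n = 1, 2, 4, ..., 64 on random grids.

```python
import random


class Node:
    def __init__(self, val, isLeaf, topLeft=None, topRight=None,
                 bottomLeft=None, bottomRight=None):
        self.val = val
        self.isLeaf = isLeaf
        self.topLeft = topLeft
        self.topRight = topRight
        self.bottomLeft = bottomLeft
        self.bottomRight = bottomRight


def construct(grid):
    """Compact copy of the recursive solution above."""
    def build(row, col, size):
        first = grid[row][col]
        if all(grid[r][c] == first
               for r in range(row, row + size)
               for c in range(col, col + size)):
            return Node(bool(first), True)
        half = size // 2
        return Node(True, False,
                    build(row, col, half), build(row, col + half, half),
                    build(row + half, col, half), build(row + half, col + half, half))
    return build(0, 0, len(grid))


def to_grid(root, n):
    """Hypothetical helper: expand a quad tree back into an n x n grid of 0s and 1s."""
    grid = [[0] * n for _ in range(n)]

    def paint(node, row, col, size):
        if node.isLeaf:
            # A leaf fills its whole region with one value.
            for r in range(row, row + size):
                for c in range(col, col + size):
                    grid[r][c] = int(node.val)
            return
        half = size // 2
        paint(node.topLeft, row, col, half)
        paint(node.topRight, row, col + half, half)
        paint(node.bottomLeft, row + half, col, half)
        paint(node.bottomRight, row + half, col + half, half)

    paint(root, 0, 0, n)
    return grid


# Round trip on random grids of every allowed size.
random.seed(0)
for x in range(7):
    n = 2 ** x
    g = [[random.randint(0, 1) for _ in range(n)] for _ in range(n)]
    assert to_grid(construct(g), n) == g
print("round trip ok")
```

A passing round trip confirms that every leaf lands on the right cells; it does not check that uniform regions were merged as early as possible.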

Complexity

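  • Time: O(n² log n): the regions at any one recursion depth are disjoint, so the uniformity scans at that depth touch at most n² cells, and there are log₂ n + 1 depths
  • Space: O(log n) for the recursion stack, since the depth is the number of times n can be halved; the returned tree itself can hold O(n²) nodes (a checkerboard grid splits all the way down to single cells)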

Optimization: Prefix Sums for O(1) Uniformity Check

Intuition

Instead of scanning the region every time to check uniformity, precompute a 2D prefix sum. Then checking if a region is uniform is just checking if the sum equals size * size (all 1s) or 0 (all 0s).
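Written out, with g the grid and P the prefix table (prefix in the code below, which has an extra leading row and column of zeros), the table recurrence and the O(1) region sum by inclusion-exclusion are:

```latex
\begin{aligned}
P_{i,j} &= \sum_{r=0}^{i-1} \sum_{c=0}^{j-1} g_{r,c}, \qquad P_{0,j} = P_{i,0} = 0 \\
P_{r+1,\,c+1} &= g_{r,c} + P_{r,\,c+1} + P_{r+1,\,c} - P_{r,\,c} \\
S(\mathit{row}, \mathit{col}, s) &= P_{\mathit{row}+s,\,\mathit{col}+s} - P_{\mathit{row},\,\mathit{col}+s} - P_{\mathit{row}+s,\,\mathit{col}} + P_{\mathit{row},\,\mathit{col}}
\end{aligned}
```

A region of side s is uniform exactly when S = 0 (all 0s) or S = s² (all 1s). For grid = [[0,1],[1,0]], P = [[0,0,0],[0,0,1],[0,1,2]], so S(0, 0, 2) = 2 - 0 - 0 + 0 = 2, which is neither 0 nor 4, and the root splits.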

Solution

class Node:
    def __init__(self, val, isLeaf, topLeft=None, topRight=None,
                 bottomLeft=None, bottomRight=None):
        self.val = val
        self.isLeaf = isLeaf
        self.topLeft = topLeft
        self.topRight = topRight
        self.bottomLeft = bottomLeft
        self.bottomRight = bottomRight

def construct_optimized(grid):
    n = len(grid)
    # Build 2D prefix sums
    prefix = [[0] * (n + 1) for _ in range(n + 1)]
    for r in range(n):
        for c in range(n):
            prefix[r + 1][c + 1] = (grid[r][c] + prefix[r][c + 1] +
                                     prefix[r + 1][c] - prefix[r][c])

    def region_sum(row, col, size):
        r1, c1 = row, col
        r2, c2 = row + size, col + size
        return prefix[r2][c2] - prefix[r1][c2] - prefix[r2][c1] + prefix[r1][c1]

    def build(row, col, size):
        total = region_sum(row, col, size)
        if total == 0:
            return Node(val=False, isLeaf=True)
        if total == size * size:
            return Node(val=True, isLeaf=True)

        half = size // 2
        return Node(
            val=True,
            isLeaf=False,
            topLeft=build(row, col, half),
            topRight=build(row, col + half, half),
            bottomLeft=build(row + half, col, half),
            bottomRight=build(row + half, col + half, half),
        )

    return build(0, 0, n)

# --- test ---
grid = [[0, 1], [1, 0]]
root = construct_optimized(grid)
print(root.isLeaf)         # False
print(root.topLeft.val)    # False
print(root.topRight.val)   # True

Complexity

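  • Time: O(n²): building the prefix table takes O(n²), and each build call does O(1) work. Depth k has at most 4^k regions, so there are at most 1 + 4 + 16 + ... + n² = (4n² - 1)/3 = O(n²) calls in total
  • Space: O(n²) for the prefix table (the output tree can also reach O(n²) nodes), plus O(log n) for the recursion stack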
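Because every call does O(1) work once the table exists, the running time is governed by how many regions the recursion visits. The sketch below counts those calls with a hypothetical helper count_nodes that repeats the prefix-sum check. A checkerboard is a worst case: every region larger than one cell is mixed, so the recursion visits all 1 + 4 + ... + n² = (4n² - 1)/3 regions, while a uniform grid stops after a single call.

```python
def count_nodes(grid):
    """Hypothetical helper: count build calls (= nodes created) with the prefix-sum check."""
    n = len(grid)
    prefix = [[0] * (n + 1) for _ in range(n + 1)]
    for r in range(n):
        for c in range(n):
            prefix[r + 1][c + 1] = (grid[r][c] + prefix[r][c + 1]
                                    + prefix[r + 1][c] - prefix[r][c])

    def region_sum(row, col, size):
        return (prefix[row + size][col + size] - prefix[row][col + size]
                - prefix[row + size][col] + prefix[row][col])

    def visit(row, col, size):
        total = region_sum(row, col, size)
        if total == 0 or total == size * size:
            return 1  # uniform region: one leaf, O(1) work
        half = size // 2
        return 1 + (visit(row, col, half) + visit(row, col + half, half)
                    + visit(row + half, col, half) + visit(row + half, col + half, half))

    return visit(0, 0, n)


for n in (1, 2, 4, 8, 16, 32, 64):
    checkerboard = [[(r + c) % 2 for c in range(n)] for r in range(n)]
    # Every mixed region splits down to size 1: 1 + 4 + ... + n*n calls.
    assert count_nodes(checkerboard) == (4 * n * n - 1) // 3

print(count_nodes([[(r + c) % 2 for c in range(8)] for r in range(8)]))  # 85
print(count_nodes([[1] * 64 for _ in range(64)]))                        # 1
```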

Common Pitfalls

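Splitting incorrectly. Each quadrant must be exactly size // 2 x size // 2, offset by 0 or half in each direction from (row, col). An off-by-one error in a row or column offset makes quadrants overlap or leave cells uncovered, which corrupts the structure.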
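One way to catch this kind of bug is to check that the four child regions cover every cell of the parent exactly once. In the sketch below, quadrants, covers_exactly, and the deliberately broken off_by_one are hypothetical names; quadrants uses the same offsets as the build functions in both solutions.

```python
def quadrants(row, col, size):
    """(row, col, size) of the four children, in the Node field order."""
    half = size // 2
    return [
        (row, col, half),                # topLeft
        (row, col + half, half),         # topRight
        (row + half, col, half),         # bottomLeft
        (row + half, col + half, half),  # bottomRight
    ]


def off_by_one(row, col, size):
    """Deliberately wrong: the right-hand quadrants start one column early."""
    half = size // 2
    return [(row, col, half), (row, col + half - 1, half),
            (row + half, col, half), (row + half, col + half - 1, half)]


def covers_exactly(split, row, col, size):
    """True if split(row, col, size) covers each cell of the region exactly once."""
    seen = []
    for r0, c0, s in split(row, col, size):
        seen.extend((r, c) for r in range(r0, r0 + s) for c in range(c0, c0 + s))
    expected = {(r, c) for r in range(row, row + size) for c in range(col, col + size)}
    return len(seen) == len(expected) and set(seen) == expected


print(quadrants(0, 4, 4))                   # [(0, 4, 2), (0, 6, 2), (2, 4, 2), (2, 6, 2)]
print(covers_exactly(quadrants, 0, 0, 8))   # True
print(covers_exactly(off_by_one, 0, 0, 8))  # False
```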

Internal node val. For internal nodes, val doesn’t carry a meaningful value per the problem definition — it can be anything. Don’t worry about setting it “correctly” for internal nodes.