Construct Quad Tree
Difficulty: Medium Source: NeetCode
Problem
Given an `n x n` matrix `grid` of `0`s and `1`s only, we want to represent `grid` with a Quad-Tree. Return the root of the Quad-Tree representing `grid`.

A Quad-Tree is a tree data structure in which each internal node has exactly four children: `topLeft`, `topRight`, `bottomLeft`, and `bottomRight`. Each node has two attributes:
- `val`: True if the node represents a region of 1s, False if it represents a region of 0s. For internal nodes any value is accepted.
- `isLeaf`: True if the node is a leaf representing a uniform region, False if it has four children.

Example 1:
Input: grid = [[0,1],[1,0]]
Output: [[0,1],[1,0],[1,1],[1,1],[1,0]]
The grid is not uniform, so the root is an internal node and each of the four cells becomes a leaf.

Example 2:
Input: grid = [[1,1,1,1,0,0,0,0],[1,1,1,1,0,0,0,0],[1,1,1,1,1,1,1,1],[1,1,1,1,1,1,1,1],[1,1,1,1,0,0,0,0],[1,1,1,1,0,0,0,0],[1,1,1,1,0,0,0,0],[1,1,1,1,0,0,0,0]]
Output: [[0,1],[1,1],[0,1],[1,1],[1,0],null,null,null,null,[1,0],[1,0],[1,1],[1,1]]
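The example outputs are a level-order serialization of the tree: each node is written as `[isLeaf, val]` with 1/0 for True/False, `null` stands for a missing child (leaves have four), and trailing nulls are dropped. A minimal sketch of that encoder, assuming a `Node` class shaped like the one in the solution below; `serialize` is a hypothetical helper name, not part of the problem's API.

```python
from collections import deque
from typing import List, Optional


class Node:
    def __init__(self, val, isLeaf, topLeft=None, topRight=None,
                 bottomLeft=None, bottomRight=None):
        self.val = val
        self.isLeaf = isLeaf
        self.topLeft = topLeft
        self.topRight = topRight
        self.bottomLeft = bottomLeft
        self.bottomRight = bottomRight


def serialize(root: Optional[Node]) -> List[Optional[List[int]]]:
    """Hypothetical helper: level-order [isLeaf, val] encoding, None for absent children."""
    out: List[Optional[List[int]]] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)
            continue
        out.append([int(node.isLeaf), int(node.val)])
        # A leaf's four children are all None and become null entries.
        queue.extend([node.topLeft, node.topRight,
                      node.bottomLeft, node.bottomRight])
    # Trailing nulls carry no information and are dropped.
    while out and out[-1] is None:
        out.pop()
    return out


# Example 1 built by hand: an internal root over four single-cell leaves.
root = Node(True, False,
            Node(False, True), Node(True, True),
            Node(True, True), Node(False, True))
print(serialize(root))  # [[0, 1], [1, 0], [1, 1], [1, 1], [1, 0]]
```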
Constraints:
- n == grid.length == grid[i].length
- n == 2^x where 0 <= x <= 6 (so 1 <= n <= 64)
- grid[i][j] is either 0 or 1
Prerequisites
Before attempting this problem, you should be comfortable with:
- Divide and Conquer — Splitting a problem into four equal quadrants
- Recursion — Solving the same problem on each quadrant
- Grid Subregions — Working with row/column offsets and sizes
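A region can be identified by its top-left corner `(row, col)` and its side length `size`; splitting it only shifts the corner by half the side. A small sketch of that bookkeeping, where `quadrants` is a hypothetical helper returning the children in the order the problem names them:

```python
from typing import List, Tuple


def quadrants(row: int, col: int, size: int) -> List[Tuple[int, int, int]]:
    """Hypothetical helper: (row, col, size) of the four child regions,
    ordered topLeft, topRight, bottomLeft, bottomRight."""
    half = size // 2
    return [
        (row, col, half),                # top-left
        (row, col + half, half),         # top-right
        (row + half, col, half),         # bottom-left
        (row + half, col + half, half),  # bottom-right
    ]


print(quadrants(0, 0, 8))  # [(0, 0, 4), (0, 4, 4), (4, 0, 4), (4, 4, 4)]
print(quadrants(0, 4, 4))  # [(0, 4, 2), (0, 6, 2), (2, 4, 2), (2, 6, 2)]
```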
1. Recursive Divide and Conquer
Intuition
Build the quad tree by recursively dividing the grid into four equal quadrants. At each step, check if the current region is uniform (all 0s or all 1s). If it is, create a leaf node. If not, split into four quadrants (top-left, top-right, bottom-left, bottom-right) and recurse on each.
Algorithm
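- Check whether all values in the current `(row, col, size)` region are the same.
- If uniform → create a leaf node with that value. A 1 x 1 region is always uniform, which ends the recursion.
- If not uniform → split into four quadrants of side `size // 2`, recurse on each, and create an internal node with the four results as children.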
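To see which regions these steps actually visit, the sketch below records every `(row, col, size)` the recursion reaches on Example 2 and whether it was uniform. `trace` is a hypothetical instrumentation helper that mirrors the uniformity check and split order of the solution; it builds no nodes.

```python
# Example 2's grid from the problem statement.
GRID = [
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0],
]


def trace(grid, row, col, size, depth=0, log=None):
    """Hypothetical helper: log (depth, row, col, size, uniform) in visit order."""
    if log is None:
        log = []
    first = grid[row][col]
    uniform = all(grid[r][c] == first
                  for r in range(row, row + size)
                  for c in range(col, col + size))
    log.append((depth, row, col, size, uniform))
    if not uniform:
        half = size // 2
        for r, c in ((row, col), (row, col + half),
                     (row + half, col), (row + half, col + half)):
            trace(grid, r, c, half, depth + 1, log)
    return log


for depth, row, col, size, uniform in trace(GRID, 0, 0, len(GRID)):
    kind = "leaf" if uniform else "split"
    print(f"{'  ' * depth}({row}, {col}, {size}) -> {kind}")
```

Only the root and the top-right quadrant split, giving 9 nodes in total (7 leaves), which matches the 9 non-null entries in Example 2's output.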
Solution
```python
class Node:
    def __init__(self, val, isLeaf, topLeft=None, topRight=None,
                 bottomLeft=None, bottomRight=None):
        self.val = val
        self.isLeaf = isLeaf
        self.topLeft = topLeft
        self.topRight = topRight
        self.bottomLeft = bottomLeft
        self.bottomRight = bottomRight


def construct(grid):
    def is_uniform(row, col, size):
        """Check if all cells in the subgrid have the same value."""
        val = grid[row][col]
        for r in range(row, row + size):
            for c in range(col, col + size):
                if grid[r][c] != val:
                    return False
        return True

    def build(row, col, size):
        if is_uniform(row, col, size):
            # Leaf node: entire region has the same value
            return Node(val=bool(grid[row][col]), isLeaf=True)
        half = size // 2
        # Split into 4 quadrants
        return Node(
            val=True,  # val for internal nodes can be anything
            isLeaf=False,
            topLeft=build(row, col, half),
            topRight=build(row, col + half, half),
            bottomLeft=build(row + half, col, half),
            bottomRight=build(row + half, col + half, half),
        )

    return build(0, 0, len(grid))
```
```python
# --- tests ---
grid1 = [[0, 1], [1, 0]]
root1 = construct(grid1)
print(root1.isLeaf)        # False (not uniform)
print(root1.topLeft.val)   # False (0)
print(root1.topRight.val)  # True (1)

grid2 = [[1, 1], [1, 1]]
root2 = construct(grid2)
print(root2.isLeaf)  # True (uniform)
print(root2.val)     # True (all 1s)

grid3 = [
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0],
]
root3 = construct(grid3)
print(root3.isLeaf)  # False


# Verify structure
def print_tree(node, indent=0):
    if not node:
        return
    prefix = "  " * indent
    if node.isLeaf:
        print(f"{prefix}Leaf({int(node.val)})")
    else:
        print(f"{prefix}Internal")
        print_tree(node.topLeft, indent + 1)
        print_tree(node.topRight, indent + 1)
        print_tree(node.bottomLeft, indent + 1)
        print_tree(node.bottomRight, indent + 1)


print_tree(root1)
# Internal
#   Leaf(0)
#   Leaf(1)
#   Leaf(1)
#   Leaf(0)
```
Complexity
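- Time: O(n² log n). There are log₂ n + 1 levels of recursion, and the regions at any one level are disjoint, so the uniformity scans at a single level touch at most n² cells in total.
- Space: O(log n) auxiliary, since the recursion depth is the number of times n can be halved. This excludes the returned tree, which can hold up to (4n² - 1)/3 nodes.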
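A commonly used alternative that also avoids the repeated scans, without an extra table, is to build bottom-up: recurse down to single cells, then merge four sibling leaves that carry the same value into one leaf. This is a different technique from the solution above; the sketch below, with the hypothetical name `construct_bottom_up`, assumes the same `Node` shape.

```python
class Node:
    def __init__(self, val, isLeaf, topLeft=None, topRight=None,
                 bottomLeft=None, bottomRight=None):
        self.val = val
        self.isLeaf = isLeaf
        self.topLeft = topLeft
        self.topRight = topRight
        self.bottomLeft = bottomLeft
        self.bottomRight = bottomRight


def construct_bottom_up(grid):
    """Hypothetical variant: build children first, merge equal leaves."""
    def build(row, col, size):
        if size == 1:
            return Node(bool(grid[row][col]), True)
        half = size // 2
        tl = build(row, col, half)
        tr = build(row, col + half, half)
        bl = build(row + half, col, half)
        br = build(row + half, col + half, half)
        children = (tl, tr, bl, br)
        if (all(ch.isLeaf for ch in children)
                and len({ch.val for ch in children}) == 1):
            # All four quadrants are uniform with the same value,
            # so the whole region is uniform: collapse into one leaf.
            return Node(tl.val, True)
        return Node(True, False, tl, tr, bl, br)

    return build(0, 0, len(grid))


EXAMPLE_2 = [
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0],
]
root = construct_bottom_up(EXAMPLE_2)
print(root.topLeft.isLeaf, root.topRight.isLeaf)  # True False
```

Every cell is reached exactly once at size 1 and there are about 4n²/3 calls in total, so this runs in O(n²) time with O(log n) recursion depth. The trade-off is allocating leaf nodes that are later merged away.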
Optimization: Prefix Sums for O(1) Uniformity Check
Intuition
Instead of scanning the region every time to check uniformity, precompute a 2D prefix sum. Then checking if a region is uniform is just checking if the sum equals size * size (all 1s) or 0 (all 0s).
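Written out, with g the grid, P the (n + 1) x (n + 1) prefix table in which P_{r,c} is the sum of g over rows < r and columns < c, and S(r, c, s) the sum of the s x s region whose top-left cell is (r, c), the table and the region query used by the code below are:

$$
\begin{aligned}
P_{0,c} &= P_{r,0} = 0, \\
P_{r+1,\,c+1} &= g_{r,c} + P_{r,\,c+1} + P_{r+1,\,c} - P_{r,c}, \\
S(r, c, s) &= P_{r+s,\,c+s} - P_{r,\,c+s} - P_{r+s,\,c} + P_{r,c}, \\
\text{uniform}(r, c, s) &\iff S(r, c, s) = 0 \ \text{(all 0s)} \ \text{or} \ S(r, c, s) = s^2 \ \text{(all 1s)}.
\end{aligned}
$$

For Example 1, g = [[0,1],[1,0]] gives P = [[0,0,0],[0,0,1],[0,1,2]], so S(0, 0, 2) = 2 - 0 - 0 + 0 = 2. That is neither 0 nor 2² = 4, so the root is split.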
Solution
```python
class Node:
    def __init__(self, val, isLeaf, topLeft=None, topRight=None,
                 bottomLeft=None, bottomRight=None):
        self.val = val
        self.isLeaf = isLeaf
        self.topLeft = topLeft
        self.topRight = topRight
        self.bottomLeft = bottomLeft
        self.bottomRight = bottomRight


def construct_optimized(grid):
    n = len(grid)

    # Build 2D prefix sums: prefix[r][c] = sum of grid over rows < r, cols < c
    prefix = [[0] * (n + 1) for _ in range(n + 1)]
    for r in range(n):
        for c in range(n):
            prefix[r + 1][c + 1] = (grid[r][c] + prefix[r][c + 1] +
                                    prefix[r + 1][c] - prefix[r][c])

    def region_sum(row, col, size):
        # Inclusion-exclusion over the size x size square at (row, col)
        r1, c1 = row, col
        r2, c2 = row + size, col + size
        return prefix[r2][c2] - prefix[r1][c2] - prefix[r2][c1] + prefix[r1][c1]

    def build(row, col, size):
        total = region_sum(row, col, size)
        if total == 0:
            return Node(val=False, isLeaf=True)  # all 0s
        if total == size * size:
            return Node(val=True, isLeaf=True)   # all 1s
        half = size // 2
        return Node(
            val=True,  # arbitrary for internal nodes
            isLeaf=False,
            topLeft=build(row, col, half),
            topRight=build(row, col + half, half),
            bottomLeft=build(row + half, col, half),
            bottomRight=build(row + half, col + half, half),
        )

    return build(0, 0, n)
```
```python
# --- test ---
grid = [[0, 1], [1, 0]]
root = construct_optimized(grid)
print(root.isLeaf)        # False
print(root.topLeft.val)   # False
print(root.topRight.val)  # True
```
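As a worked example of the O(1) check, the sketch below builds the same prefix table for Example 2 and classifies its four top-level quadrants by their sums. `build_prefix` and this standalone `region_sum` are hypothetical module-level versions of the closures inside `construct_optimized`.

```python
# Example 2's grid from the problem statement.
GRID = [
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0],
]


def build_prefix(grid):
    """Hypothetical helper: (n + 1) x (n + 1) table of 2D prefix sums."""
    n = len(grid)
    prefix = [[0] * (n + 1) for _ in range(n + 1)]
    for r in range(n):
        for c in range(n):
            prefix[r + 1][c + 1] = (grid[r][c] + prefix[r][c + 1]
                                    + prefix[r + 1][c] - prefix[r][c])
    return prefix


def region_sum(prefix, row, col, size):
    """Sum of the size x size square whose top-left cell is (row, col)."""
    r2, c2 = row + size, col + size
    return prefix[r2][c2] - prefix[row][c2] - prefix[r2][col] + prefix[row][col]


prefix = build_prefix(GRID)
for name, (row, col) in [("topLeft", (0, 0)), ("topRight", (0, 4)),
                         ("bottomLeft", (4, 0)), ("bottomRight", (4, 4))]:
    total = region_sum(prefix, row, col, 4)
    verdict = "all 0s" if total == 0 else "all 1s" if total == 16 else "mixed"
    print(name, total, verdict)
# topLeft 16 all 1s
# topRight 8 mixed
# bottomLeft 16 all 1s
# bottomRight 0 all 0s
```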
Complexity
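- Time: O(n²). Building the prefix table takes O(n²), and each call to `build` does O(1) work. A region of side `size` is reached only if its parent was mixed, so there are at most (n / size)² calls per level; summed over all levels this is at most 1 + 4 + 16 + ... + n² = (4n² - 1)/3 = O(n²) calls.
- Space: O(n²) for the (n + 1) x (n + 1) prefix-sum table, plus O(log n) recursion depth.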
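The (4n² - 1)/3 bound is reached by a checkerboard, where every region larger than one cell is mixed, so every level is fully expanded. The instrumented sketch below counts `build` calls under the same prefix-sum test; `count_calls` and `checkerboard` are hypothetical helpers, and no nodes are allocated.

```python
def count_calls(grid):
    """Hypothetical helper: number of build() calls made by the prefix-sum version."""
    n = len(grid)
    prefix = [[0] * (n + 1) for _ in range(n + 1)]
    for r in range(n):
        for c in range(n):
            prefix[r + 1][c + 1] = (grid[r][c] + prefix[r][c + 1]
                                    + prefix[r + 1][c] - prefix[r][c])
    calls = 0

    def build(row, col, size):
        nonlocal calls
        calls += 1
        r2, c2 = row + size, col + size
        total = prefix[r2][c2] - prefix[row][c2] - prefix[r2][col] + prefix[row][col]
        if total == 0 or total == size * size:
            return  # uniform region: a leaf, no further calls
        half = size // 2
        build(row, col, half)
        build(row, col + half, half)
        build(row + half, col, half)
        build(row + half, col + half, half)

    build(0, 0, n)
    return calls


def checkerboard(n):
    return [[(r + c) % 2 for c in range(n)] for r in range(n)]


for n in (1, 2, 4, 8, 16, 32, 64):
    print(n, count_calls(checkerboard(n)), (4 * n * n - 1) // 3)
```

Each printed row shows the measured call count next to (4n² - 1)/3, and the two agree for every allowed n.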
Common Pitfalls
Splitting incorrectly. Each quadrant must be exactly size // 2 x size // 2. Off-by-one errors in row/col offsets will corrupt the structure.
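One way to catch such offset bugs is to paint the finished tree back into a grid and compare it with the input: a wrong quadrant offset either paints the wrong cells or leaves some cells unpainted. A sketch, with `paint` and `round_trips` as hypothetical test helpers and a compact copy of the first solution's logic as the tree under test:

```python
import random


class Node:
    def __init__(self, val, isLeaf, topLeft=None, topRight=None,
                 bottomLeft=None, bottomRight=None):
        self.val = val
        self.isLeaf = isLeaf
        self.topLeft = topLeft
        self.topRight = topRight
        self.bottomLeft = bottomLeft
        self.bottomRight = bottomRight


def construct(grid):
    """Compact copy of the top-down solution, used as the code under test."""
    def build(row, col, size):
        first = grid[row][col]
        if all(grid[r][c] == first
               for r in range(row, row + size) for c in range(col, col + size)):
            return Node(bool(first), True)
        half = size // 2
        return Node(True, False,
                    build(row, col, half), build(row, col + half, half),
                    build(row + half, col, half), build(row + half, col + half, half))
    return build(0, 0, len(grid))


def paint(node, out, row, col, size):
    """Hypothetical helper: write each leaf's value over its region in out."""
    if node.isLeaf:
        for r in range(row, row + size):
            for c in range(col, col + size):
                out[r][c] = int(node.val)
        return
    half = size // 2
    paint(node.topLeft, out, row, col, half)
    paint(node.topRight, out, row, col + half, half)
    paint(node.bottomLeft, out, row + half, col, half)
    paint(node.bottomRight, out, row + half, col + half, half)


def round_trips(grid, root):
    """True if painting the tree reproduces grid exactly (-1 marks unpainted cells)."""
    n = len(grid)
    out = [[-1] * n for _ in range(n)]
    paint(root, out, 0, 0, n)
    return out == grid


# Randomized self-check on one grid of every allowed size.
random.seed(0)
for x in range(7):
    n = 2 ** x
    grid = [[random.randint(0, 1) for _ in range(n)] for _ in range(n)]
    assert round_trips(grid, construct(grid))
print("all round-trips passed")
```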
Internal node val. For internal nodes, val doesn’t carry a meaningful value per the problem definition — it can be anything. Don’t worry about setting it “correctly” for internal nodes.