
Partition to K Equal Sum Subsets

Difficulty: Medium Source: NeetCode

Problem

Given an integer array nums and an integer k, return true if it is possible to divide this array into k non-empty subsets whose sums are all equal.

Example 1:
Input: nums = [4, 3, 2, 3, 5, 2, 1], k = 4
Output: true
Explanation: It’s possible to divide it into 4 subsets (5), (1, 4), (2, 3), (2, 3) with equal sums.

Example 2:
Input: nums = [1, 2, 3, 4], k = 3
Output: false
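Both examples can be sanity-checked against the two cheap necessary conditions: the total must divide evenly by k, and no element may exceed target. The sketch below uses a hypothetical helper, `quick_reject`, introduced only for this illustration. Passing these checks does not by itself prove that a partition exists; it only fails to rule one out.

```python
def quick_reject(nums, k):
    """Hypothetical helper: True if the cheap necessary conditions already rule out a partition."""
    total = sum(nums)
    if total % k != 0:
        return True                 # the sum cannot be split into k equal parts
    return max(nums) > total // k   # some element is too big for any bucket


ex1 = [4, 3, 2, 3, 5, 2, 1]
partition = [[5], [1, 4], [2, 3], [2, 3]]   # the partition quoted in Example 1

print(sum(ex1) // 4, [sum(p) for p in partition])          # 5 [5, 5, 5, 5]
print(sorted(x for p in partition for x in p) == sorted(ex1))  # True: same elements
print(quick_reject(ex1, 4))            # False: the cheap checks pass
print(quick_reject([1, 2, 3, 4], 3))   # True: 10 is not divisible by 3
```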

Constraints:

  • 1 <= k <= len(nums) <= 16
  • 0 < nums[i] < 10000
  • Every element of nums will fit in a 32-bit integer.

Prerequisites

Before attempting this problem, you should be comfortable with:

  • Matchsticks to Square — this is essentially the same problem with k sides instead of 4
  • Backtracking — trying to fill k buckets one element at a time
  • Pruning — avoiding redundant branches to keep runtime feasible

1. Brute Force

Intuition

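This is a generalization of Matchsticks to Square. Each of the k buckets must reach exactly target = sum(nums) / k. Take the elements one at a time and try placing each one into every bucket that still has room for it. A bucket that has reached target is full, so no later element can enter it; if placing the element would push a bucket past target, skip that bucket. When a placement leads to a dead end, undo it and try the next bucket.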

Algorithm

  1. If sum(nums) % k != 0: return False.
  2. Set target = sum(nums) // k.
  3. Sort nums descending; if nums[0] > target: return False.
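  4. Define backtrack(index), which places nums[index] into one of the k buckets (a shared array of bucket sums).
  5. If index == len(nums), every element has been placed: return whether every bucket equals target.
  6. Try placing nums[index] into each bucket, skipping buckets it would overflow and buckets whose current sum was already tried for this element.
  7. If the recursive call fails, backtrack (remove the element from that bucket) and try the next bucket.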

Solution

def canPartitionKSubsets_brute(nums, k):
    total = sum(nums)
    if total % k != 0:
        return False
    target = total // k
    nums.sort(reverse=True)
    if nums[0] > target:
        return False

    buckets = [0] * k

    def backtrack(index):
        if index == len(nums):
            return all(b == target for b in buckets)
        seen = set()
        for i in range(k):
            if buckets[i] in seen:
                continue
            if buckets[i] + nums[index] <= target:
                seen.add(buckets[i])
                buckets[i] += nums[index]
                if backtrack(index + 1):
                    return True
                buckets[i] -= nums[index]
        return False

    return backtrack(0)


print(canPartitionKSubsets_brute([4, 3, 2, 3, 5, 2, 1], 4))  # True
print(canPartitionKSubsets_brute([1, 2, 3, 4], 3))            # False

Complexity

  • Time: O(k^n) — each element can go into k buckets
  • Space: O(n + k) — recursion depth plus buckets array
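The O(k^n) bound counts every way to assign n elements to k buckets. A literal version of that count, sketched here with `itertools.product`, checks all k^n assignments with no pruning at all. The name `can_partition_exhaustive` is illustrative; the function is only practical for very small inputs, but it makes a convenient oracle for cross-checking the backtracking answer.

```python
from itertools import product


def can_partition_exhaustive(nums, k):
    """Try all k**n bucket assignments; only practical for tiny inputs."""
    total = sum(nums)
    if total % k != 0:
        return False
    target = total // k
    # each tuple assigns element j to bucket assignment[j]
    for assignment in product(range(k), repeat=len(nums)):
        sums = [0] * k
        for value, bucket in zip(nums, assignment):
            sums[bucket] += value
        # all nums are positive, so a bucket summing to target is also non-empty
        if all(s == target for s in sums):
            return True
    return False


print(can_partition_exhaustive([4, 3, 2, 3, 5, 2, 1], 4))  # True
print(can_partition_exhaustive([1, 2, 3, 4], 3))            # False
print(4 ** 7)  # 16384: upper bound on assignments tried for Example 1
```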

2. Backtracking with Bitmask Memoization

Intuition

Instead of tracking which elements are in each bucket, record which elements have been used in a bitmask: with n bits, bit i is 1 when nums[i] has been placed. The current bucket’s running sum can be computed from the bitmask: every completed bucket holds exactly target, so the open bucket holds the sum of the used elements modulo target. Cache results: if we’ve seen this exact bitmask before and it returned False, return False immediately, since the same set of remaining elements (with the same partial bucket) has already been ruled out.
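The claim that the mask alone determines the running sum can be checked directly. This is a minimal sketch with a hypothetical helper, `current_bucket_sum`; it assumes the mask was produced by the search itself, so no bucket ever overflowed along the way.

```python
def current_bucket_sum(mask, nums, target):
    """Running sum of the open bucket, derived from mask alone (hypothetical helper)."""
    # Every completed bucket contributed exactly `target`, so the remainder is the open bucket.
    used = sum(nums[i] for i in range(len(nums)) if mask & (1 << i))
    return used % target


nums = [5, 4, 3, 3, 2, 2, 1]  # Example 1 sorted descending; target = 20 // 4 = 5
print(current_bucket_sum(0b0000011, nums, 5))  # 4: {5} is full, 4 sits in the open bucket
print(current_bucket_sum(0b1000011, nums, 5))  # 0: {5} and {4, 1} are both full
```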

Algorithm

  1. Compute target and apply the same early checks and descending sort as in the brute-force approach.
  2. Define backtrack(mask, current_sum):
    • current_sum is the running sum of the current (incomplete) bucket.
    • If mask == (1 << n) - 1: all elements placed, return True.
    • For each i not in mask:
      • If current_sum + nums[i] <= target:
        • If current_sum + nums[i] == target: recurse with mask | (1 << i) and current_sum = 0 (bucket complete).
        • Else: recurse with mask | (1 << i) and current_sum + nums[i].
  3. Memoize on mask.

Solution

def canPartitionKSubsets(nums, k):
    total = sum(nums)
    if total % k != 0:
        return False
    target = total // k
    nums.sort(reverse=True)
    if nums[0] > target:
        return False

    n = len(nums)
    memo = {}

    def backtrack(mask, current_sum):
        if mask == (1 << n) - 1:
            return True
        if mask in memo:
            return memo[mask]

        for i in range(n):
            if mask & (1 << i):  # already used
                continue
            if current_sum + nums[i] > target:
                continue
            new_sum = (current_sum + nums[i]) % target  # reset to 0 when bucket completes
            if backtrack(mask | (1 << i), new_sum):
                memo[mask] = True
                return True

        memo[mask] = False
        return False

    return backtrack(0, 0)


print(canPartitionKSubsets([4, 3, 2, 3, 5, 2, 1], 4))  # True
print(canPartitionKSubsets([1, 2, 3, 4], 3))            # False
print(canPartitionKSubsets([2, 2, 2, 2, 3, 4, 5], 4))  # False

Complexity

  • Time: O(2^n * n) — at most 2^n unique masks, each exploring up to n elements
  • Space: O(2^n) — memoization table
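The O(2^n * n) bound is easiest to see in an iterative (bottom-up) version of the same mask DP. This is an alternative formulation swapped in for illustration, not the recursive solution above; `can_partition_bottom_up` is a hypothetical name. Here dp[mask] stores the open bucket’s running sum, or -1 if mask cannot be reached without overflowing a bucket. Every mask is visited once and each visit scans n elements.

```python
def can_partition_bottom_up(nums, k):
    total = sum(nums)
    if total % k != 0:
        return False
    target = total // k
    n = len(nums)
    # dp[mask]: open-bucket sum after placing the elements in mask, or -1 if unreachable
    dp = [-1] * (1 << n)
    dp[0] = 0
    for mask in range(1 << n):
        if dp[mask] == -1:
            continue
        for i in range(n):
            if mask & (1 << i):              # nums[i] already placed
                continue
            if dp[mask] + nums[i] > target:  # would overflow the open bucket
                continue
            nxt = mask | (1 << i)
            if dp[nxt] == -1:
                dp[nxt] = (dp[mask] + nums[i]) % target  # 0 when the bucket fills
    return dp[(1 << n) - 1] == 0


print(can_partition_bottom_up([4, 3, 2, 3, 5, 2, 1], 4))  # True
print(can_partition_bottom_up([1, 2, 3, 4], 3))            # False
print(can_partition_bottom_up([2, 2, 2, 2, 3, 4, 5], 4))  # False
```

Processing masks in increasing order is safe because placing an element only ever sets a bit, so every transition goes from a smaller mask to a larger one. The first value written to dp[nxt] is final, since any valid path to the same mask leaves the same open-bucket sum.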

Common Pitfalls

Not sorting descending. A large element fits into fewer buckets than a small one, so placing large elements first makes dead ends appear near the root of the recursion, where pruning them cuts off the largest subtrees.

Missing the duplicate-bucket skip. Without seen, you’ll try placing an element into bucket 1 (sum=3) and bucket 2 (sum=3) as separate branches, but they lead to identical subproblems. For each element, record the bucket sums already tried and skip any bucket whose current sum is in that set.
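One way to see how much the sort order and the seen skip matter is to count backtrack calls. The sketch below is a hypothetical instrumented copy of the brute-force solution; `count_calls` and its `descending`/`use_seen` flags are illustrative names. On an instance with no valid partition the whole tree is explored, so for a fixed sort order the seen skip can only keep the count the same or lower it. Exact counts depend on the input, so none are quoted here.

```python
def count_calls(nums, k, descending=True, use_seen=True):
    """Hypothetical instrumented copy of approach 1: returns (answer, backtrack calls)."""
    total = sum(nums)
    if total % k != 0:
        return False, 0
    target = total // k
    if max(nums) > target:
        return False, 0
    order = sorted(nums, reverse=descending)  # sorted copy; the caller's list is untouched
    buckets = [0] * k
    calls = 0

    def backtrack(index):
        nonlocal calls
        calls += 1
        if index == len(order):
            return all(b == target for b in buckets)
        seen = set()
        for i in range(k):
            if use_seen and buckets[i] in seen:
                continue  # same sum already tried for this element: identical subtree
            if buckets[i] + order[index] <= target:
                seen.add(buckets[i])
                buckets[i] += order[index]
                if backtrack(index + 1):
                    return True
                buckets[i] -= order[index]
        return False

    return backtrack(0), calls


nums = [2, 2, 2, 2, 3, 4, 5]  # k = 4 has no valid partition, so every branch is explored
for descending in (True, False):
    for use_seen in (True, False):
        print(f"descending={descending} use_seen={use_seen}:", count_calls(nums, 4, descending, use_seen))
```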

Wrong bitmask memoization. The memoization key is mask alone (which elements have been used), not (mask, current_sum). This works because the current bucket’s partial sum equals the sum of the used elements mod target, which is determined by mask. Using (mask, current_sum) as the key is also valid but redundant: since current_sum is a function of mask, it produces exactly the same set of cached states and only makes each key larger.

Forgetting the % target trick. When current_sum + nums[i] == target, the bucket is complete and the new bucket starts at 0. Because the overflow check has already ensured current_sum + nums[i] <= target, (current_sum + nums[i]) % target cleanly handles both cases: it’s 0 when the bucket completes and current_sum + nums[i] otherwise.
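A tiny sketch of the reset, using a hypothetical helper `next_sum` with target = 5. It also shows why the overflow check has to come first: on its own, % would silently wrap an overflowing sum into a plausible-looking value.

```python
def next_sum(current_sum, value, target):
    """Open-bucket sum after adding value; resets to 0 exactly when the bucket fills."""
    if current_sum + value > target:
        # the search skips this element instead of placing it
        raise ValueError("bucket would overflow")
    return (current_sum + value) % target


print(next_sum(3, 2, 5))  # 0: bucket complete, the next bucket starts empty
print(next_sum(1, 2, 5))  # 3: bucket still open
print((4 + 3) % 5)        # 2: what % alone would report for an overflowing bucket
```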