Partition to K Equal Sum Subsets
Difficulty: Medium
Source: NeetCode
Problem
Given an integer array nums and an integer k, return true if it is possible to divide this array into k non-empty subsets whose sums are all equal.
Example 1:
Input: nums = [4, 3, 2, 3, 5, 2, 1], k = 4
Output: true
Explanation: It’s possible to divide it into 4 subsets (5), (1, 4), (2, 3), (2, 3) with equal sums.
Example 2:
Input: nums = [1, 2, 3, 4], k = 3
Output: false
Constraints:
- 1 <= k <= len(nums) <= 16
- 0 < nums[i] < 10000
- Every element of nums will fit in a 32-bit integer.
- It is guaranteed that the answer is unique.
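To make the goal concrete, here is a minimal sketch of a checker for a proposed partition. The helper name is_valid_partition is hypothetical and is not part of the problem statement; it only verifies a candidate answer and does not search for one.

```python
from collections import Counter

def is_valid_partition(nums, subsets):
    """Check that `subsets` uses every element of `nums` exactly once,
    that no subset is empty, and that all subset sums are equal."""
    if any(len(s) == 0 for s in subsets):
        return False
    flat = [x for s in subsets for x in s]
    if Counter(flat) != Counter(nums):
        return False  # an element is missing, duplicated, or invented
    return len({sum(s) for s in subsets}) == 1

# The partition given in Example 1.
print(is_valid_partition([4, 3, 2, 3, 5, 2, 1], [[5], [1, 4], [2, 3], [2, 3]]))  # True
```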
Prerequisites
Before attempting this problem, you should be comfortable with:
- Matchsticks to Square — this is essentially the same problem with k sides instead of 4
- Backtracking — trying to fill k buckets one element at a time
- Pruning — avoiding redundant branches to keep runtime feasible
1. Brute Force
Intuition
This is a generalization of Matchsticks to Square. Each of the k buckets must reach exactly target = sum(nums) / k. Try placing each unused element into each bucket in turn. If a bucket reaches target, it’s done — leave it and move on. If placing the element exceeds target, skip that bucket.
Algorithm
- If sum(nums) % k != 0: return False.
- Set target = sum(nums) // k.
- Sort nums descending; if nums[0] > target: return False.
- Define backtrack(index, buckets).
- If all buckets are target: return True.
- Try placing nums[index] into each bucket (skip if overflow).
- Backtrack (remove from bucket) and try the next bucket.
Solution
def canPartitionKSubsets_brute(nums, k):
    total = sum(nums)
    if total % k != 0:
        return False
    target = total // k
    nums.sort(reverse=True)  # place large elements first so dead ends show up early
    if nums[0] > target:
        return False
    buckets = [0] * k  # running sum of each bucket

    def backtrack(index):
        if index == len(nums):
            return all(b == target for b in buckets)
        seen = set()  # bucket sums already tried for nums[index]
        for i in range(k):
            if buckets[i] in seen:
                continue  # a bucket with the same sum was already tried
            if buckets[i] + nums[index] <= target:
                seen.add(buckets[i])
                buckets[i] += nums[index]
                if backtrack(index + 1):
                    return True
                buckets[i] -= nums[index]  # undo and try the next bucket
        return False

    return backtrack(0)

print(canPartitionKSubsets_brute([4, 3, 2, 3, 5, 2, 1], 4))  # True
print(canPartitionKSubsets_brute([1, 2, 3, 4], 3))  # False
Complexity
- Time: O(k^n), since each element can go into any of the k buckets
- Space: O(n + k) for the recursion depth plus the buckets array
2. Backtracking with Bitmask Memoization
Intuition
Instead of tracking which elements are in each bucket, track which elements have been “used” using a bitmask. A bitmask of n bits where bit i is 1 means nums[i] has been placed. The current bucket’s running sum can be computed from the bitmask. Cache results: if we’ve seen this exact bitmask before and returned False, return False immediately — we don’t need to explore the same set of remaining elements again.
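The claim that the running sum can be computed from the bitmask can be sketched as follows. The helper current_sum_from_mask is a hypothetical name used only for illustration, and the result is meaningful only for masks the search can actually reach (buckets filled one at a time, never exceeding target).

```python
def current_sum_from_mask(nums, mask, target):
    """Partial sum of the bucket being filled, derived from the used-element mask."""
    used_total = sum(x for i, x in enumerate(nums) if mask & (1 << i))
    # Every completed bucket contributes exactly `target`, so the remainder
    # is what sits in the current, incomplete bucket.
    return used_total % target

nums = [5, 4, 3, 3, 2, 2, 1]   # Example 1, sorted descending
target = 5

print(current_sum_from_mask(nums, 0b0000001, target))  # 0: the 5 filled a bucket by itself
print(current_sum_from_mask(nums, 0b0000011, target))  # 4: the 5 is done, the 4 is pending
print(current_sum_from_mask(nums, 0b1000011, target))  # 0: 4 + 1 completed a second bucket
```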
Algorithm
- Compute target.
- Define backtrack(mask, current_sum): current_sum is the running sum of the current (incomplete) bucket.
  - If mask == (1 << n) - 1: all elements placed, return True.
  - For each i not in mask:
    - If current_sum + nums[i] <= target:
      - If current_sum + nums[i] == target: recurse with mask | (1 << i) and current_sum = 0 (bucket complete).
      - Else: recurse with mask | (1 << i) and current_sum + nums[i].
- Memoize on mask.
Solution
def canPartitionKSubsets(nums, k):
    total = sum(nums)
    if total % k != 0:
        return False
    target = total // k
    nums.sort(reverse=True)
    if nums[0] > target:
        return False
    n = len(nums)
    memo = {}  # mask -> whether the remaining elements can complete the partition

    def backtrack(mask, current_sum):
        if mask == (1 << n) - 1:
            return True
        if mask in memo:
            return memo[mask]
        for i in range(n):
            if mask & (1 << i):  # already used
                continue
            if current_sum + nums[i] > target:
                continue
            new_sum = (current_sum + nums[i]) % target  # reset to 0 when bucket completes
            if backtrack(mask | (1 << i), new_sum):
                memo[mask] = True
                return True
        memo[mask] = False
        return False

    return backtrack(0, 0)

print(canPartitionKSubsets([4, 3, 2, 3, 5, 2, 1], 4))  # True
print(canPartitionKSubsets([1, 2, 3, 4], 3))  # False
print(canPartitionKSubsets([2, 2, 2, 2, 3, 4, 5], 4))  # False
Complexity
- Time: O(2^n * n), since there are at most 2^n unique masks, each exploring up to n elements
- Space: O(2^n) for the memoization table
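The same mask-indexed state can also be filled in iteratively rather than recursively. The sketch below is a different formulation from the memoized backtracking above, not the method this article uses; the function name can_partition_iterative and the partial array are assumptions for illustration. It has the same O(2^n * n) time and O(2^n) space bounds.

```python
def can_partition_iterative(nums, k):
    total = sum(nums)
    if total % k != 0:
        return False
    target = total // k
    n = len(nums)
    full = (1 << n) - 1

    # partial[mask] is the sum in the current (incomplete) bucket once the
    # elements in `mask` have been placed, or -1 if `mask` is unreachable.
    partial = [-1] * (1 << n)
    partial[0] = 0

    # mask | bit is always larger than mask, so increasing order visits
    # every state after all of its predecessors.
    for mask in range(1 << n):
        if partial[mask] == -1:
            continue
        for i in range(n):
            bit = 1 << i
            if mask & bit:
                continue  # nums[i] is already used
            if partial[mask] + nums[i] <= target:
                partial[mask | bit] = (partial[mask] + nums[i]) % target

    # Reachable with every element placed means every bucket closed exactly.
    return partial[full] == 0

print(can_partition_iterative([4, 3, 2, 3, 5, 2, 1], 4))  # True
print(can_partition_iterative([1, 2, 3, 4], 3))           # False
```

Unlike the recursive version, this sketch does not sort or otherwise modify the input list.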
Common Pitfalls
Not sorting descending. Large elements fail faster (fewer valid buckets). Sorting descending lets you prune bad branches early in the recursion.
Missing the duplicate-bucket skip. Without seen, you’ll try placing an element into bucket 1 (sum=3) and bucket 2 (sum=3) as separate branches — but they’re identical. Track the set of current bucket sums and skip repeats.
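One way to see the effect of the duplicate-bucket skip is to count recursive calls with and without it. The following is a sketch only: count_calls and its skip_duplicates flag are hypothetical names, and it works on a sorted copy so the caller's list is left untouched.

```python
def count_calls(nums, k, skip_duplicates):
    """Return (answer, number of backtrack calls) for the bucket-filling search."""
    total = sum(nums)
    if total % k != 0:
        return False, 0
    target = total // k
    nums = sorted(nums, reverse=True)
    if nums[0] > target:
        return False, 0
    buckets = [0] * k
    calls = 0

    def backtrack(index):
        nonlocal calls
        calls += 1
        if index == len(nums):
            return all(b == target for b in buckets)
        seen = set()
        for i in range(k):
            if skip_duplicates and buckets[i] in seen:
                continue  # same sum as a bucket already tried for this element
            if buckets[i] + nums[index] <= target:
                seen.add(buckets[i])
                buckets[i] += nums[index]
                if backtrack(index + 1):
                    return True
                buckets[i] -= nums[index]
        return False

    return backtrack(0), calls

nums = [2, 2, 2, 2, 3, 4, 5]  # infeasible for k = 4, so the whole tree is explored
print(count_calls(nums, 4, skip_duplicates=True))
print(count_calls(nums, 4, skip_duplicates=False))
```

Both runs return the same answer; the run with the skip makes fewer calls because symmetric branches are cut.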
Wrong bitmask memoization. The memoization key is mask alone (which elements have been used), not (mask, current_sum). This works because the current bucket’s partial sum equals the sum of the used elements mod target, which is determined by mask. Using (mask, current_sum) as the key is also valid but redundant: every reachable mask pairs with exactly one current_sum, so it caches the same states with larger keys.
Forgetting the % target trick. When current_sum + nums[i] == target, the bucket is complete and the new bucket starts at 0. Using (current_sum + nums[i]) % target cleanly handles this: it’s 0 when the bucket completes and current_sum + nums[i] otherwise. The trick relies on the overflow check running first, so that current_sum + nums[i] never exceeds target.
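A small sketch of that update rule, isolated from the search. The helper name next_bucket_sum is hypothetical; it assumes the caller has already rejected any placement that would overflow the bucket.

```python
def next_bucket_sum(current_sum, value, target):
    """Running sum after adding `value`, resetting to 0 when the bucket fills."""
    assert current_sum + value <= target, "overflow must be filtered out first"
    return (current_sum + value) % target

target = 5
print(next_bucket_sum(0, 2, target))  # 2: bucket still open
print(next_bucket_sum(2, 3, target))  # 0: bucket completed, next one starts empty
print(next_bucket_sum(3, 1, target))  # 4: bucket still open
print(next_bucket_sum(0, 5, target))  # 0: a single element fills the bucket
```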