Sum of All Subsets XOR Total
Difficulty: Easy
Source: NeetCode
Problem
The XOR total of an array is defined as the bitwise XOR of all its elements, or 0 if the array is empty.

Given an array nums, return the sum of all XOR totals for every subset of nums.

Note: Subsets with the same elements should be counted multiple times.
Example 1:
Input: nums = [1, 3]
Output: 6
Explanation: Subsets: {} → 0, {1} → 1, {3} → 3, {1, 3} → 1 ^ 3 = 2. Sum = 0 + 1 + 3 + 2 = 6.

Example 2:
Input: nums = [5, 1, 6]
Output: 28

Constraints:
- 1 <= nums.length <= 12
- 1 <= nums[i] <= 20
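To make the definition concrete, here is a small sketch that lists every subset of Example 1 together with its XOR total. It assumes the standard-library tools `itertools.combinations` (to generate subsets by size) and `functools.reduce` with `operator.xor` (initial value 0, so the empty subset yields 0); the helper name `xor_totals` is hypothetical, chosen only for illustration.

```python
from functools import reduce
from itertools import combinations
from operator import xor


def xor_totals(nums):
    """Return (subset, XOR total) pairs for every subset of nums (hypothetical helper)."""
    pairs = []
    for size in range(len(nums) + 1):
        # combinations picks by position, so equal values at different
        # indices still produce distinct subsets (matching the problem's note)
        for subset in combinations(nums, size):
            pairs.append((subset, reduce(xor, subset, 0)))
    return pairs


pairs = xor_totals([1, 3])
for subset, total in pairs:
    print(subset, "->", total)  # () -> 0, (1,) -> 1, (3,) -> 3, (1, 3) -> 2
print("sum =", sum(total for _, total in pairs))  # sum = 6
```

Because `combinations` treats elements as unique by position, an input like [2, 2] yields four subsets, two of which are {2}.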
Prerequisites
Before attempting this problem, you should be comfortable with:
- Bit manipulation: XOR behaves differently from addition; know that a ^ a = 0 and a ^ 0 = a
- Subsets / power set: an array of length n has 2^n subsets
- Backtracking: recursively including or excluding each element
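A quick sanity check of the facts listed above, sketched with arbitrary sample values: the two XOR identities, the fact that XOR is not addition, and the 2^n subset count. The hypothetical `power_set` helper builds subsets the backtracking way: each element either stays out of every existing subset or is added to a copy of it, which doubles the list.

```python
def power_set(nums):
    """Build all subsets; each element doubles the list (exclude it or include it)."""
    subsets = [[]]
    for x in nums:
        subsets = subsets + [s + [x] for s in subsets]
    return subsets


# XOR identities, checked on a few sample values
for a in [0, 1, 5, 20]:
    assert a ^ a == 0  # a value XORed with itself cancels out
    assert a ^ 0 == a  # 0 is the identity for XOR

print(1 ^ 1, 1 + 1)               # 0 2 (XOR has no carry)
print(len(power_set([5, 1, 6])))  # 8, i.e. 2^3
```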
1. Brute Force
Intuition
Enumerate every subset of nums by either including or excluding each element, compute the XOR total of each subset, and add it to a running sum. With at most 12 elements there are at most 2^12 = 4096 subsets, so exhaustive enumeration is cheap.
Algorithm
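- Define a recursive function dfs(index, current_xor), where current_xor is the XOR of the elements chosen so far.
- If index == len(nums), every element has been either included or excluded, so current_xor is the XOR total of exactly one subset: add it to the total and return.
- Otherwise, recurse twice: once including nums[index] (XOR it in) and once excluding it (current_xor unchanged).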
Solution
```python
def subsetXORSum_brute(nums):
    total = 0

    def dfs(index, current_xor):
        nonlocal total
        if index == len(nums):
            # every element has been decided, so current_xor is one subset's XOR total
            total += current_xor
            return
        # include nums[index]
        dfs(index + 1, current_xor ^ nums[index])
        # exclude nums[index]
        dfs(index + 1, current_xor)

    dfs(0, 0)
    return total


print(subsetXORSum_brute([1, 3]))              # 6
print(subsetXORSum_brute([5, 1, 6]))           # 28
print(subsetXORSum_brute([3, 4, 5, 6, 7, 8]))  # 480
```
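The same brute force can also be written without recursion by treating each integer mask in 0 .. 2^n - 1 as a subset, where bit i of the mask says whether nums[i] is included. This is a swapped-in iterative variant, not the solution above, and the name `subset_xor_sum_bitmask` is hypothetical.

```python
def subset_xor_sum_bitmask(nums):
    n = len(nums)
    total = 0
    for mask in range(1 << n):  # 2^n masks, one per subset
        current_xor = 0
        for i in range(n):
            if (mask >> i) & 1:  # bit i set means nums[i] is in this subset
                current_xor ^= nums[i]
        total += current_xor
    return total


print(subset_xor_sum_bitmask([1, 3]))              # 6
print(subset_xor_sum_bitmask([5, 1, 6]))           # 28
print(subset_xor_sum_bitmask([3, 4, 5, 6, 7, 8]))  # 480
```

It does O(n) work per mask, so it runs in O(n * 2^n) time, slightly more than the recursive version, but needs only O(1) extra space.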
Complexity
- Time: O(2^n). The recursion tree has 2^n leaves, one per subset, and 2^(n+1) - 1 calls in total, each doing O(1) work.
- Space: O(n) for the recursion stack.
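To see where the O(2^n) bound comes from, the sketch below instruments the same include/exclude recursion to count total calls and leaf calls. The counters and the name `count_calls` are illustrative additions, not part of the solution.

```python
def count_calls(nums):
    """Run the include/exclude recursion and return (total calls, leaf calls)."""
    calls = 0
    leaves = 0

    def dfs(index, current_xor):
        nonlocal calls, leaves
        calls += 1
        if index == len(nums):
            leaves += 1  # one leaf per subset
            return
        dfs(index + 1, current_xor ^ nums[index])
        dfs(index + 1, current_xor)

    dfs(0, 0)
    return calls, leaves


for n in range(1, 5):
    print(n, count_calls(list(range(1, n + 1))))
# prints: 1 (3, 2) / 2 (7, 4) / 3 (15, 8) / 4 (31, 16)
```

The leaf count doubles with each extra element, and the total call count stays just under twice the leaf count, which is why the work is O(2^n).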
2. Bit Observation
Intuition
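There’s a beautiful math shortcut here. Look at the answer one bit position b at a time. If no element of nums has bit b set, no subset's XOR total has it set either, so bit b contributes nothing. If at least one element x has bit b set, pair every subset that excludes x with the same subset plus x: adding x flips bit b of the XOR total, so exactly one subset in each pair has bit b set. That is exactly half of the 2^n subsets, i.e. 2^(n-1) of them, and each contributes 2^b to the sum. The bits that contribute are therefore exactly the bits set in the OR of all elements, each counted 2^(n-1) times, so the answer is OR(nums) * 2^(n-1). This turns an exponential problem into an O(n) one.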
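The bit-by-bit counting argument can be checked numerically. This sketch brute-forces every subset and counts, for each bit position b, how many subset XOR totals have bit b set, then compares that with the prediction: 2^(n-1) if some element has bit b set, 0 otherwise. The name `bit_counts` and the 5-bit width (enough because nums[i] <= 20 < 32) are assumptions for illustration.

```python
def bit_counts(nums, width=5):
    """For each bit b < width, count subsets whose XOR total has bit b set."""
    n = len(nums)
    counts = [0] * width
    for mask in range(1 << n):
        x = 0
        for i in range(n):
            if (mask >> i) & 1:
                x ^= nums[i]
        for b in range(width):
            counts[b] += (x >> b) & 1
    return counts


nums = [5, 1, 6]
or_total = 5 | 1 | 6  # 0b111
predicted = [(1 << (len(nums) - 1)) if (or_total >> b) & 1 else 0 for b in range(5)]
print(bit_counts(nums))  # [4, 4, 4, 0, 0]
print(predicted)         # [4, 4, 4, 0, 0]
```

Weighting each count by 2^b and summing gives 4*1 + 4*2 + 4*4 = 28, matching Example 2.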
Algorithm
- Compute or_total = OR of all elements in nums.
- Return or_total * (2 ** (len(nums) - 1)).
Solution
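```python
def subsetXORSum(nums):
    # only bits set somewhere in nums can ever appear in a subset's XOR total
    or_total = 0
    for num in nums:
        or_total |= num
    # each such bit is set in the XOR total of exactly 2^(n-1) subsets
    return or_total * (2 ** (len(nums) - 1))


print(subsetXORSum([1, 3]))              # 6
print(subsetXORSum([5, 1, 6]))           # 28
print(subsetXORSum([3, 4, 5, 6, 7, 8]))  # 480
```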
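Because the formula rests on a counting argument, it is worth cross-checking it against exhaustive enumeration on random inputs within the stated constraints. This is a testing sketch with choices of our own: the helper names, the fixed seed, and the 200 trials are arbitrary, and the formula is written with `functools.reduce` and a left shift (1 << (n - 1) equals 2 ** (n - 1)).

```python
import random
from functools import reduce
from itertools import combinations
from operator import or_, xor


def formula(nums):
    # OR of all elements times 2^(n-1), written with a shift (assumes n >= 1)
    return reduce(or_, nums, 0) << (len(nums) - 1)


def exhaustive(nums):
    # sum the XOR total of every subset, including the empty one (XOR 0)
    return sum(
        reduce(xor, subset, 0)
        for size in range(len(nums) + 1)
        for subset in combinations(nums, size)
    )


rng = random.Random(0)
for _ in range(200):
    n = rng.randint(1, 12)                         # 1 <= nums.length <= 12
    nums = [rng.randint(1, 20) for _ in range(n)]  # 1 <= nums[i] <= 20
    assert formula(nums) == exhaustive(nums), nums
print("formula matches exhaustive enumeration")
```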
Complexity
- Time: O(n), a single pass to compute the OR
- Space: O(1)
Common Pitfalls
Forgetting the empty subset. The empty subset has an XOR total of 0, so including or omitting it does not change the sum. What does matter is the recursion's base case: it must add current_xor before returning (0 for the empty subset, the subset's XOR total otherwise). A base case that returns without adding anything drops every subset, and the function returns 0.
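A minimal illustration of the base-case bug: the hypothetical function below takes an `add_at_base` flag (an illustrative parameter, not part of the solution) that controls whether the base case adds current_xor before returning.

```python
def subset_xor_sum(nums, add_at_base=True):
    total = 0

    def dfs(index, current_xor):
        nonlocal total
        if index == len(nums):
            if add_at_base:
                total += current_xor  # correct: count this subset's XOR total
            return  # with add_at_base=False nothing is ever counted
        dfs(index + 1, current_xor ^ nums[index])
        dfs(index + 1, current_xor)

    dfs(0, 0)
    return total


print(subset_xor_sum([1, 3]))                     # 6
print(subset_xor_sum([1, 3], add_at_base=False))  # 0
```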
Confusing XOR with OR. The per-subset operation is XOR, but the shortcut formula ORs all the elements together. These are different operations, so don’t mix them up: for [1, 3], the XOR of the elements is 2 while the OR is 3.
Off-by-one in the exponent. The formula is OR_total * 2^(n-1), not OR_total * 2^n. Each bit set in OR_total is set in the XOR total of exactly half of the 2^n subsets, so the multiplier is 2^(n-1).
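Both formula mistakes from the last two pitfalls are easy to catch on Example 1. This sketch compares the correct formula with two hypothetical wrong variants: one that XORs the elements instead of ORing them, and one that multiplies by 2^n instead of 2^(n-1).

```python
from functools import reduce
from operator import or_, xor


def correct(nums):
    return reduce(or_, nums, 0) * 2 ** (len(nums) - 1)


def xor_instead_of_or(nums):
    # wrong: combines the elements with XOR instead of OR
    return reduce(xor, nums, 0) * 2 ** (len(nums) - 1)


def exponent_off_by_one(nums):
    # wrong: counts every subset instead of half of them
    return reduce(or_, nums, 0) * 2 ** len(nums)


nums = [1, 3]
print(correct(nums), xor_instead_of_or(nums), exponent_off_by_one(nums))  # 6 4 12
```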