Sum of All Subsets XOR Total

Difficulty: Easy
Source: NeetCode

Problem

The XOR total of an array is defined as the bitwise XOR of all its elements, or 0 if the array is empty.

Given an array nums, return the sum of all XOR totals for every subset of nums.

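Note: Subsets are distinguished by the positions they draw from. If nums contains equal values at different indices, subsets with the same elements are each counted separately.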
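The sketch below shows the two definitions in code. It assumes subsets are identified by index positions, which is what makes duplicate values count separately. The helper names `xor_total` and `all_subsets` are illustrative and not part of the problem statement.

```python
from functools import reduce
from itertools import combinations
from operator import xor


def xor_total(values):
    """Bitwise XOR of all values; the initial value 0 covers the empty case."""
    return reduce(xor, values, 0)


def all_subsets(nums):
    """Yield every subset chosen by index, so equal values at different positions stay distinct."""
    for size in range(len(nums) + 1):
        for indices in combinations(range(len(nums)), size):
            yield [nums[i] for i in indices]


print(xor_total([]))              # 0
print(xor_total([1, 3]))          # 2
print(list(all_subsets([2, 2])))  # [[], [2], [2], [2, 2]]
```

In the output for [2, 2], the subset {2} appears twice, once for each index. That is the counting rule the note describes.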

Example 1:
Input: nums = [1, 3]
Output: 6
Explanation: The subsets are {} → 0, {1} → 1, {3} → 3, and {1, 3} → 1 XOR 3 = 2. Sum = 0 + 1 + 3 + 2 = 6.

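Example 2:
Input: nums = [5, 1, 6]
Output: 28
Explanation: The eight subsets have XOR totals {} → 0, {5} → 5, {1} → 1, {6} → 6, {5, 1} → 4, {5, 6} → 3, {1, 6} → 7, {5, 1, 6} → 2. Sum = 0 + 5 + 1 + 6 + 4 + 3 + 7 + 2 = 28.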

Constraints:

  • 1 <= nums.length <= 12
  • 1 <= nums[i] <= 20

Prerequisites

Before attempting this problem, you should be comfortable with:

  • Bit manipulation — XOR behaves differently from addition; know that a ^ a = 0 and a ^ 0 = a
  • Subsets / power set — understanding that an array of length n has 2^n subsets
  • Backtracking — recursively including or excluding each element
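Here is a quick check of the first two prerequisites in Python. The sample values are arbitrary. The key difference from addition is that a value XORed in twice cancels out, and the subset count grows as 2^n.

```python
a, b = 13, 6

assert a ^ a == 0          # a value cancels itself
assert a ^ 0 == a          # 0 is the identity element for XOR
assert (a ^ b) ^ b == a    # XORing b in twice undoes it
assert a + a != 0          # addition has no such cancellation

n = 12
print(2 ** n)  # 4096 subsets for the largest allowed input
```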

1. Brute Force

Intuition

Enumerate every subset of nums by deciding, for each element, whether to include or exclude it. Compute each subset's XOR total and add it to a running sum. With at most 12 elements there are at most 2^12 = 4096 subsets, so exhaustive enumeration is cheap.
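The same 2^n subsets can also be enumerated without recursion. Treat each integer mask in range(2 ** n) as a subset, where bit i of the mask says whether nums[i] is included. This is a swapped-in iterative technique, not the recursive DFS used in this section. The function name is hypothetical. It runs in O(n · 2^n) time because it rebuilds each subset's XOR from scratch.

```python
def subset_xor_sum_bitmask(nums):
    n = len(nums)
    total = 0
    for mask in range(1 << n):            # each mask encodes one subset
        current_xor = 0
        for i in range(n):
            if (mask >> i) & 1:           # bit i set -> nums[i] is in this subset
                current_xor ^= nums[i]
        total += current_xor
    return total


print(subset_xor_sum_bitmask([1, 3]))     # 6
print(subset_xor_sum_bitmask([5, 1, 6]))  # 28
```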

Algorithm

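  1. Define a recursive function dfs(index, current_xor), where current_xor is the XOR of the elements chosen so far from nums[0..index-1].
  2. If index == len(nums), every element has been decided and current_xor is the XOR total of one complete subset. Add it to total and return.
  3. Otherwise, recurse twice: once including nums[index] (pass current_xor ^ nums[index]) and once excluding it (pass current_xor unchanged).
  4. Start with dfs(0, 0) and return total.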
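To see which subset each leaf of the recursion corresponds to, the same include/exclude DFS can record the chosen elements next to the running XOR. This tracing variant uses a hypothetical helper, `trace_subsets`, and is meant only for illustration.

```python
def trace_subsets(nums):
    leaves = []

    def dfs(index, chosen, current_xor):
        if index == len(nums):
            leaves.append((list(chosen), current_xor))  # one leaf per subset
            return
        chosen.append(nums[index])                      # include nums[index]
        dfs(index + 1, chosen, current_xor ^ nums[index])
        chosen.pop()                                    # undo before excluding
        dfs(index + 1, chosen, current_xor)             # exclude nums[index]

    dfs(0, [], 0)
    return leaves


for subset, value in trace_subsets([1, 3]):
    print(subset, value)
# [1, 3] 2
# [1] 1
# [3] 3
# [] 0
```

The four leaves match the subsets listed in Example 1, and their values sum to 6.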

Solution

def subsetXORSum_brute(nums):
    total = 0

    def dfs(index, current_xor):
        nonlocal total
        if index == len(nums):
            total += current_xor
            return
        # include nums[index]
        dfs(index + 1, current_xor ^ nums[index])
        # exclude nums[index]
        dfs(index + 1, current_xor)

    dfs(0, 0)
    return total


print(subsetXORSum_brute([1, 3]))       # 6
print(subsetXORSum_brute([5, 1, 6]))    # 28
print(subsetXORSum_brute([3, 4, 5, 6, 7, 8]))  # 480

Complexity

  • Time: O(2^n). The recursion tree has 2^n leaves, one per subset, and 2^(n+1) - 1 calls in total, each doing O(1) work.
  • Space: O(n) — recursion depth
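The include/exclude recursion forms a full binary tree of depth n, so it makes 2^(n+1) - 1 calls in total. An instrumented copy of the DFS can confirm this empirically. The `calls` counter is added only for this check and is not part of the solution.

```python
def count_dfs_calls(nums):
    calls = 0

    def dfs(index, current_xor):
        nonlocal calls
        calls += 1                     # count every invocation, leaves included
        if index == len(nums):
            return
        dfs(index + 1, current_xor ^ nums[index])
        dfs(index + 1, current_xor)

    dfs(0, 0)
    return calls


for n in range(1, 6):
    print(n, count_dfs_calls(list(range(1, n + 1))), 2 ** (n + 1) - 1)
```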

2. Bit Observation

Intuition

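There's a beautiful math shortcut here. Look at one bit position b at a time. If no element has bit b set, no subset's XOR total has it set either, so bit b contributes nothing. If at least one element x has bit b set, pair every subset that excludes x with the same subset plus x. Adding x flips bit b of the XOR total, so in each of the 2^(n-1) pairs exactly one subset has bit b set. Bit b therefore contributes 2^(n-1) * 2^b to the sum exactly when some element has it set, which is exactly when bit b is set in the OR of all elements. Summing over all bits gives OR_total * 2^(n-1), and the exponential enumeration becomes a single O(n) pass.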
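The per-bit argument can also be written as a derivation. The notation is introduced here only for this purpose: X(S) is the XOR total of subset S, and OR is the bitwise OR of all elements. The middle step uses the per-bit count, which is 2^{n-1} when some element has bit b set and 0 otherwise.

```latex
\sum_{S \subseteq \mathrm{nums}} X(S)
  = \sum_{b \ge 0} 2^{b} \cdot \bigl|\{\, S : \text{bit } b \text{ of } X(S) \text{ is } 1 \,\}\bigr|
  = \sum_{b \,:\, \text{bit } b \text{ of } \mathrm{OR} \text{ is } 1} 2^{b} \cdot 2^{n-1}
  = 2^{n-1} \cdot \mathrm{OR}
```

For nums = [1, 3], OR = 3 and n = 2, giving 2 · 3 = 6, which matches Example 1.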

Algorithm

  1. Compute or_total = OR of all elements in nums.
  2. Return or_total * (2 ** (len(nums) - 1)).
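One way to gain confidence in the shortcut is to compare it against brute-force enumeration on many random inputs drawn from the stated constraints. This is a testing sketch, not part of either solution. It recomputes both sides inline so it runs on its own, and it uses the standard `random` module with a fixed seed so runs are repeatable.

```python
import random
from functools import reduce
from itertools import combinations
from operator import or_, xor


def brute_force(nums):
    # combinations() picks by position, so duplicate values count separately
    return sum(
        reduce(xor, combo, 0)
        for size in range(len(nums) + 1)
        for combo in combinations(nums, size)
    )


def shortcut(nums):
    # OR of all elements times 2^(n-1), written as a left shift (n >= 1)
    return reduce(or_, nums, 0) << (len(nums) - 1)


rng = random.Random(0)
for _ in range(200):
    n = rng.randint(1, 10)                        # keep brute force fast
    nums = [rng.randint(1, 20) for _ in range(n)]
    assert brute_force(nums) == shortcut(nums), nums
print("all checks passed")
```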

Solution

def subsetXORSum(nums):
    or_total = 0
    for num in nums:
        or_total |= num
    return or_total * (2 ** (len(nums) - 1))


print(subsetXORSum([1, 3]))       # 6
print(subsetXORSum([5, 1, 6]))    # 28
print(subsetXORSum([3, 4, 5, 6, 7, 8]))  # 480

Complexity

  • Time: O(n) — single pass to compute OR
  • Space: O(1)

Common Pitfalls

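Forgetting the empty subset. The empty subset has an XOR total of 0, so it contributes nothing, and the answer is the same whether or not you count it. What matters is where the sum is accumulated. In the include/exclude DFS, every subset (the empty one included) is added exactly once, at a leaf, so the base case must add current_xor. Adding current_xor at every level of that recursion instead would count the same subset several times; for [1, 3] it yields 7 instead of 6.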

Confusing XOR with OR. Each subset's total is an XOR, but the shortcut formula ORs all elements together. Using XOR in the formula gives the wrong answer. For [1, 3], 1 ^ 3 = 2 and 2 * 2^1 = 4, not 6.

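Off-by-one in the exponent. The formula is OR_total * 2^(n-1), not 2^n. For each bit that appears in some element, exactly half of the subsets have that bit set in their XOR total, so the multiplier is 2^(n-1). Using 2^n doubles the answer (12 instead of 6 for [1, 3]).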