Merge K Sorted Lists
Difficulty: Hard Source: NeetCode
Problem
You are given an array lists of k linked-lists, each sorted in ascending order. Merge all the linked-lists into one sorted linked-list and return it.
Example 1: Input: lists = [[1,4,5],[1,3,4],[2,6]] Output: [1,1,2,3,4,4,5,6]
Example 2: Input: lists = [] Output: []
Example 3: Input: lists = [[]] Output: []
Constraints:
- k == lists.length
- 0 <= k <= 10^4
- 0 <= lists[i].length <= 500
- -10^4 <= lists[i][j] <= 10^4
- lists[i] is sorted in ascending order
- The sum of all lists[i].length will not exceed 10^4
Prerequisites
Before attempting this problem, you should be comfortable with:
- Merge Two Sorted Lists — the core merge operation (see problem 2)
- Min-Heap / Priority Queue — efficiently finding the smallest among k candidates
- Divide and Conquer — breaking a problem into halves recursively
1. Brute Force — Collect All, Sort, Rebuild
Intuition
Don’t think about the list structure at all. Dump every value from every list into a Python list, sort it, then build a new linked list. It is simple, but it ignores the fact that every input list is already sorted.
Algorithm
- Walk all k lists and collect all values into one Python list.
- Sort the values.
- Build a new linked list from the sorted values.
Solution
```python
class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

def build_list(values):
    if not values:
        return None
    head = ListNode(values[0])
    cur = head
    for v in values[1:]:
        cur.next = ListNode(v)
        cur = cur.next
    return head

def print_list(head):
    result = []
    while head:
        result.append(str(head.val))
        head = head.next
    print(" -> ".join(result) if result else "None")

def mergeKLists_brute(lists):
    # Collect every value from every list
    values = []
    for head in lists:
        cur = head
        while cur:
            values.append(cur.val)
            cur = cur.next
    # Sort, then rebuild a fresh linked list behind a dummy head
    values.sort()
    dummy = ListNode(0)
    cur = dummy
    for v in values:
        cur.next = ListNode(v)
        cur = cur.next
    return dummy.next

# Test cases
lists = [build_list([1, 4, 5]), build_list([1, 3, 4]), build_list([2, 6])]
print("Output: ", end=""); print_list(mergeKLists_brute(lists))  # 1 -> 1 -> 2 -> 3 -> 4 -> 4 -> 5 -> 6
print("Output: ", end=""); print_list(mergeKLists_brute([]))      # None
print("Output: ", end=""); print_list(mergeKLists_brute([None]))  # None
```
Complexity
- Time: O(N log N), where N = total number of nodes across all lists
- Space: O(N) for the collected values, plus the N new nodes of the output list
2. Min-Heap (Optimal)
Intuition
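We want to always pick the smallest current head across all k lists. A min-heap holding at most k entries returns that minimum in O(log k) per pop, and inserting the popped node's successor also costs O(log k). We maintain a heap of (value, index, node) tuples, one entry per list for its current head node.
Why include the index? Python compares tuples element by element, so when two entries have the same value it moves on to the next element. Without the index, it would compare the ListNode objects, which define no ordering, and raise a TypeError. The index i is unique per list, and each list has at most one entry in the heap at a time, so the comparison never reaches the node.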
lists: [1->4->5], [1->3->4], [2->6]
Initial heap: [(1,0,n), (1,1,n), (2,2,n)], where n is that list's head node
(ordered by value, then list index)
Pop (1,0,n): append the node, push (4,0,n.next)
Pop (1,1,n): append the node, push (3,1,n.next)
Pop (2,2,n): append the node, push (6,2,n.next)
Pop (3,1,n): append the node, push (4,1,n.next)
...and so on until the heap is empty
Algorithm
- Create a dummy head node and a tail pointer cur.
- Initialize a min-heap with (node.val, i, node) for each non-empty list head.
- While the heap is not empty:
  - Pop the smallest (val, i, node).
  - Attach node to the result list.
  - If node.next exists, push (node.next.val, i, node.next) onto the heap.
- Return dummy.next.
Solution
```python
import heapq

class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

def build_list(values):
    if not values:
        return None
    head = ListNode(values[0])
    cur = head
    for v in values[1:]:
        cur.next = ListNode(v)
        cur = cur.next
    return head

def print_list(head):
    result = []
    while head:
        result.append(str(head.val))
        head = head.next
    print(" -> ".join(result) if result else "None")

def mergeKLists(lists):
    dummy = ListNode(0)
    cur = dummy
    # Initialize heap with (value, list_index, node);
    # list_index breaks ties when values are equal
    heap = []
    for i, head in enumerate(lists):
        if head:
            heapq.heappush(heap, (head.val, i, head))
    while heap:
        val, i, node = heapq.heappop(heap)
        # Reuse the popped node in the output list
        cur.next = node
        cur = cur.next
        # Replace it with the next node from the same list
        if node.next:
            heapq.heappush(heap, (node.next.val, i, node.next))
    return dummy.next

# Test cases
lists = [build_list([1, 4, 5]), build_list([1, 3, 4]), build_list([2, 6])]
print("Output: ", end=""); print_list(mergeKLists(lists))  # 1 -> 1 -> 2 -> 3 -> 4 -> 4 -> 5 -> 6
print("Output: ", end=""); print_list(mergeKLists([]))      # None
print("Output: ", end=""); print_list(mergeKLists([None]))  # None
# Single list
print("Output: ", end=""); print_list(mergeKLists([build_list([1, 2, 3])]))  # 1 -> 2 -> 3
# All single-element lists
lists2 = [build_list([3]), build_list([1]), build_list([2])]
print("Output: ", end=""); print_list(mergeKLists(lists2))  # 1 -> 2 -> 3
```
Complexity
- Time: O(N log k), where N = total number of nodes; each node is pushed and popped once on a heap of size at most k
- Space: O(k) for the heap; the output reuses the input nodes instead of copying them
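Python's standard library also provides a lazy k-way merge, heapq.merge, built on the same heap idea. The sketch below is an alternative to the solution above, not a replacement for understanding it. It wraps each linked list in a generator (iter_nodes, a hypothetical helper) and uses the key argument (available since Python 3.5) to order nodes by val. heapq.merge breaks ties by the position of each input iterable, so ListNode objects are never compared.

```python
import heapq

class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

def build_list(values):
    dummy = ListNode(0)
    cur = dummy
    for v in values:
        cur.next = ListNode(v)
        cur = cur.next
    return dummy.next

def to_pylist(head):
    out = []
    while head:
        out.append(head.val)
        head = head.next
    return out

def iter_nodes(head):
    # Hypothetical helper: yields nodes one by one. heapq.merge advances this
    # generator (reading head.next) right after yielding a node, before the
    # consumer below overwrites that node's next pointer.
    while head:
        yield head
        head = head.next

def mergeKLists_stdlib(lists):
    dummy = ListNode(0)
    cur = dummy
    # Stable k-way merge ordered by node value; ties go to the earlier list
    for node in heapq.merge(*(iter_nodes(h) for h in lists), key=lambda n: n.val):
        cur.next = node
        cur = node
    return dummy.next

print(to_pylist(mergeKLists_stdlib([build_list([1, 4, 5]), build_list([1, 3, 4]), build_list([2, 6])])))
```

heapq.merge keeps one entry per non-exhausted input, so this has the same O(N log k) time and O(k) heap space as the hand-written version.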
3. Divide and Conquer
Intuition
Instead of merging all lists at once, merge them pairwise, like a tournament bracket. Pair up lists [0,1], [2,3], [4,5], etc., merge each pair, then repeat with the halved set of lists. After ⌈log₂ k⌉ rounds, only one merged list remains. Each round reuses our O(m+n) merge-two-lists function.
Example with k = 6 lists:
Round 1: merge(0,1), merge(2,3), merge(4,5) → 3 lists
Round 2: merge(0+1, 2+3), and 4+5 passes through (merged with None) → 2 lists
Round 3: merge(0..3, 4+5) → 1 list
Total work: each node participates in O(log k) merges → O(N log k) overall.
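The same bound, written out as a short count: the number of lists goes from k to ⌈k/2⌉ each round, so it reaches 1 after ⌈log₂ k⌉ rounds. Within a round, each of the N nodes sits in exactly one merge, and a merge of lengths m and n costs O(m + n), so one round costs O(N).

```latex
\text{rounds} = \lceil \log_2 k \rceil, \qquad
\text{work per round} = O(N)
\;\Longrightarrow\;
T(N, k) = O\!\left(N \lceil \log_2 k \rceil\right) = O(N \log k)
```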
Solution
```python
class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

def build_list(values):
    if not values:
        return None
    head = ListNode(values[0])
    cur = head
    for v in values[1:]:
        cur.next = ListNode(v)
        cur = cur.next
    return head

def print_list(head):
    result = []
    while head:
        result.append(str(head.val))
        head = head.next
    print(" -> ".join(result) if result else "None")

def mergeTwoLists(l1, l2):
    # Standard two-pointer merge; relinks existing nodes (no copies)
    dummy = ListNode(0)
    cur = dummy
    while l1 and l2:
        if l1.val <= l2.val:
            cur.next = l1
            l1 = l1.next
        else:
            cur.next = l2
            l2 = l2.next
        cur = cur.next
    # At most one list still has nodes; attach the remainder as-is
    cur.next = l1 if l1 else l2
    return dummy.next

def mergeKLists_dc(lists):
    if not lists:
        return None
    # Each round merges neighbors pairwise, roughly halving the number of lists
    while len(lists) > 1:
        merged = []
        for i in range(0, len(lists), 2):
            l1 = lists[i]
            l2 = lists[i + 1] if i + 1 < len(lists) else None  # odd one out passes through
            merged.append(mergeTwoLists(l1, l2))
        lists = merged
    return lists[0]

# Test cases
lists = [build_list([1, 4, 5]), build_list([1, 3, 4]), build_list([2, 6])]
print("Output: ", end=""); print_list(mergeKLists_dc(lists))  # 1 -> 1 -> 2 -> 3 -> 4 -> 4 -> 5 -> 6
print("Output: ", end=""); print_list(mergeKLists_dc([]))      # None
print("Output: ", end=""); print_list(mergeKLists_dc([None]))  # None
lists2 = [build_list([3]), build_list([1]), build_list([2])]
print("Output: ", end=""); print_list(mergeKLists_dc(lists2))  # 1 -> 2 -> 3
```
Complexity
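- Time: O(N log k): ⌈log₂ k⌉ rounds of merging, and each round touches each of the N nodes at most once
- Space: O(k) for the list of merged heads in the iterative version above (⌈k/2⌉ pointers after the first round); a top-down recursive version uses O(log k) stack space instead. In both cases the nodes themselves are relinked, not copied.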
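The O(log k) stack-space figure refers to a top-down recursive variant. A minimal sketch, assuming the same mergeTwoLists as above; merge_range is a hypothetical inner helper that merges lists[lo..hi] by splitting the index range in half, so the recursion depth is about log₂ k.

```python
class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

def build_list(values):
    dummy = ListNode(0)
    cur = dummy
    for v in values:
        cur.next = ListNode(v)
        cur = cur.next
    return dummy.next

def to_pylist(head):
    out = []
    while head:
        out.append(head.val)
        head = head.next
    return out

def mergeTwoLists(l1, l2):
    dummy = ListNode(0)
    cur = dummy
    while l1 and l2:
        if l1.val <= l2.val:
            cur.next = l1
            l1 = l1.next
        else:
            cur.next = l2
            l2 = l2.next
        cur = cur.next
    cur.next = l1 if l1 else l2
    return dummy.next

def mergeKLists_rec(lists):
    def merge_range(lo, hi):
        # Hypothetical helper: merge lists[lo..hi] (inclusive)
        if lo > hi:
            return None          # empty input
        if lo == hi:
            return lists[lo]     # a single list is already merged
        mid = (lo + hi) // 2
        left = merge_range(lo, mid)
        right = merge_range(mid + 1, hi)
        return mergeTwoLists(left, right)
    return merge_range(0, len(lists) - 1)

print(to_pylist(mergeKLists_rec([build_list([1, 4, 5]), build_list([1, 3, 4]), build_list([2, 6])])))
```

The work per level of recursion is the same as one round of the iterative version, so the time stays O(N log k); only the bookkeeping moves from a merged list of heads to the call stack.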
Common Pitfalls
Using (val, node) in the heap without an index tiebreaker. If two nodes have the same value, Python tries to compare ListNode objects. Since ListNode doesn’t implement __lt__, this raises a TypeError. Always include the list index as a tiebreaker: (val, i, node).
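A minimal reproduction of this pitfall, using the same ListNode class: two nodes with equal values are pushed first without and then with the list index.

```python
import heapq

class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

a, b = ListNode(1), ListNode(1)

# (val, node): equal values make Python fall through to comparing the nodes
bad_heap = []
heapq.heappush(bad_heap, (a.val, a))
try:
    heapq.heappush(bad_heap, (b.val, b))
    raised = False
except TypeError:
    raised = True
print("TypeError without index:", raised)

# (val, i, node): distinct indices settle the tie before the nodes are reached
good_heap = []
heapq.heappush(good_heap, (a.val, 0, a))
heapq.heappush(good_heap, (b.val, 1, b))
val, i, node = heapq.heappop(good_heap)
print("popped from list", i, "node is a:", node is a)
```

The first push of b fails with TypeError, while the indexed version pops list 0's node first without ever comparing ListNode objects.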
Not checking if lists[i] is None before pushing to heap. Some lists might be empty. Only push to the heap if the head node is not None.
Naive sequential merging: O(kN) instead of O(N log k). Merging list by list (merge the first two, merge the result with the third, and so on) re-traverses the growing merged list in every one of the k-1 merges, so nodes from the first list are processed up to k-1 times. Use the heap or divide-and-conquer approach to avoid this.
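To make the gap concrete, the sketch below counts work under a simple cost model, assumed here for illustration: each two-list merge is charged len(l1) + len(l2), the number of nodes it may touch. It compares sequential merging against the pairwise rounds of the divide-and-conquer solution on k = 16 single-node lists. length, merge_sequential, and merge_pairwise are hypothetical helpers.

```python
class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

def build_list(values):
    dummy = ListNode(0)
    cur = dummy
    for v in values:
        cur.next = ListNode(v)
        cur = cur.next
    return dummy.next

def length(head):
    n = 0
    while head:
        n += 1
        head = head.next
    return n

def to_pylist(head):
    out = []
    while head:
        out.append(head.val)
        head = head.next
    return out

def mergeTwoLists(l1, l2):
    dummy = ListNode(0)
    cur = dummy
    while l1 and l2:
        if l1.val <= l2.val:
            cur.next = l1
            l1 = l1.next
        else:
            cur.next = l2
            l2 = l2.next
        cur = cur.next
    cur.next = l1 if l1 else l2
    return dummy.next

def merge_sequential(lists):
    if not lists:
        return None, 0
    cost = 0
    result = lists[0]
    for head in lists[1:]:
        # The whole merged-so-far list is charged again on every merge
        cost += length(result) + length(head)
        result = mergeTwoLists(result, head)
    return result, cost

def merge_pairwise(lists):
    if not lists:
        return None, 0
    cost = 0
    while len(lists) > 1:
        merged = []
        for i in range(0, len(lists), 2):
            l1 = lists[i]
            l2 = lists[i + 1] if i + 1 < len(lists) else None
            cost += length(l1) + length(l2)
            merged.append(mergeTwoLists(l1, l2))
        lists = merged
    return lists[0], cost

k = 16
seq_head, seq_cost = merge_sequential([build_list([v]) for v in range(k)])
dc_head, dc_cost = merge_pairwise([build_list([v]) for v in range(k)])
print("sequential:", seq_cost, "pairwise:", dc_cost)
```

Under this model the sequential fold costs 2 + 3 + ... + k = k(k+1)/2 - 1 (135 for k = 16), which grows quadratically in k, while the pairwise rounds cost k per round over log₂ k rounds (64 for k = 16).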