Skip to content

Heap Merge K Sorted

Table of Contents

23. Merge k Sorted Lists

"""
-   Prerequisite: 21. Merge Two Sorted Lists
-   Video explanation: [23. Merge K Sorted Lists - NeetCode](https://youtu.be/q5a5OiGbT6Q?si=SQ2dCvsYQ3LQctPh)
"""

import copy
import heapq
from typing import List, Optional

from leetpattern.utils import ListNode, list_from_array, list_to_array


def merge_k_lists_divide_conquer(
    lists: List[Optional[ListNode]],
) -> Optional[ListNode]:
    """Merge sorted linked lists by merging them pairwise, round after round.

    Every round merges neighbouring lists two at a time, so the number of
    lists is roughly halved until only one remains. Existing nodes are
    relinked rather than copied, so the input lists are consumed.

    Args:
        lists: Heads of the linked lists to merge; entries may be ``None``.

    Returns:
        Head of the merged list, or ``None`` when ``lists`` is empty.
    """
    if not lists:
        return None

    def merge_pair(a, b):
        """Splice two sorted lists into one and return its head."""
        sentinel = ListNode()
        tail = sentinel

        while a and b:
            # Strict `<`: on equal values the node from `b` goes first.
            if a.val < b.val:
                picked, a = a, a.next
            else:
                picked, b = b, b.next
            tail.next = picked
            tail = picked

        # At most one side still has nodes; attach it as-is.
        tail.next = a or b
        return sentinel.next

    pending = lists
    while len(pending) > 1:
        # An odd list out is paired with None, i.e. carried over unchanged.
        pending = [
            merge_pair(
                pending[i],
                pending[i + 1] if i + 1 < len(pending) else None,
            )
            for i in range(0, len(pending), 2)
        ]

    return pending[0]


def merge_k_lists_heap(lists: List[Optional[ListNode]]) -> Optional[ListNode]:
    """Merge sorted linked lists using a min-heap of the current list heads.

    Existing nodes are relinked rather than copied, so the input lists are
    consumed.

    Args:
        lists: Heads of the linked lists to merge; entries may be ``None``.

    Returns:
        Head of the merged list, or ``None`` if there are no nodes at all.
    """
    sentinel = ListNode()
    tail = sentinel

    # Entries are (val, idx, node). Each list has at most one node in the
    # heap at a time, so idx breaks ties before nodes would be compared.
    frontier = [
        (head.val, idx, head) for idx, head in enumerate(lists) if head
    ]
    heapq.heapify(frontier)

    while frontier:
        _, source, smallest = heapq.heappop(frontier)
        tail.next = smallest
        tail = smallest

        # Replace the popped node with its successor from the same list.
        successor = smallest.next
        if successor:
            heapq.heappush(frontier, (successor.val, source, successor))

    return sentinel.next


def test_merge_k_lists() -> None:
    """Check that both merge implementations produce the same sorted output."""
    expected = [1, 1, 2, 3, 4, 6]
    heads = [list_from_array(values) for values in ([1, 4], [1, 3], [2, 6])]

    # Both implementations relink the nodes they are given, so each one
    # works on its own deep copy of the input lists.
    for merge in (merge_k_lists_divide_conquer, merge_k_lists_heap):
        assert list_to_array(merge(copy.deepcopy(heads))) == expected

373. Find K Pairs with Smallest Sums

import heapq
from typing import List


# Heap - Merge K Sorted
def kSmallestPairs(
    nums1: List[int], nums2: List[int], k: int
) -> List[List[int]]:
    """Return up to ``k`` pairs ``[u, v]`` ordered by increasing ``u + v``.

    ``u`` is taken from ``nums1`` and ``v`` from ``nums2``. The approach is
    intended for inputs sorted in ascending order; this is not checked.

    Args:
        nums1: First array of numbers.
        nums2: Second array of numbers.
        k: Maximum number of pairs to return.

    Returns:
        A list of at most ``k`` two-element lists; empty if either array is
        empty or ``k`` is not positive.
    """
    if not (nums1 and nums2) or k <= 0:
        return []

    # Seed with nums1[0] paired against the first min(k, len(nums2)) values
    # of nums2; entries are (sum, index into nums1, index into nums2).
    frontier = [
        (nums1[0] + nums2[j], 0, j) for j in range(min(k, len(nums2)))
    ]
    heapq.heapify(frontier)

    pairs = []
    last_i = len(nums1) - 1
    while frontier and len(pairs) < k:
        _, i, j = heapq.heappop(frontier)
        pairs.append([nums1[i], nums2[j]])

        # Advance along nums1 for the same nums2[j].
        if i < last_i:
            heapq.heappush(frontier, (nums1[i + 1] + nums2[j], i + 1, j))

    return pairs


if __name__ == "__main__":
    # Smoke check using the sample input from the problem statement.
    nums1, nums2, k = [1, 2, 4, 5, 6], [3, 5, 7, 9], 3
    expected = [[1, 3], [2, 3], [1, 5]]
    assert kSmallestPairs(nums1, nums2, k) == expected

378. Kth Smallest Element in a Sorted Matrix

"""
-   Given an `n x n` matrix where each row and each column is sorted in ascending order, return the `k`-th smallest element in the matrix.
"""

from heapq import heapify, heappop, heappush
from typing import List


# Heap - Merge K Sorted
def kthSmallestHeap(matrix: List[List[int]], k: int) -> int:
    """Return the ``k``-th smallest element (1-indexed) of a row-sorted matrix.

    Each row is treated as a sorted list and the rows are merged through a
    min-heap. Only the rows need to be sorted ascending, and they may have
    any length (square, rectangular or ragged); empty rows are ignored.

    Args:
        matrix: Rows of numbers, each sorted in ascending order.
        k: 1-based rank of the element to return. Values below 1 behave
            like 1 and return the smallest element.

    Returns:
        The ``k``-th smallest element across all rows.

    Raises:
        IndexError: If ``k`` exceeds the total number of elements (this
            includes a matrix with no elements).
    """
    # Seed the heap with the first element of every non-empty row;
    # entries are (value, row index, column index).
    heap = [(row[0], r, 0) for r, row in enumerate(matrix) if row]
    heapify(heap)

    for _ in range(k - 1):
        _, row, col = heappop(heap)

        # Bound by this row's own length rather than the number of rows, so
        # non-square matrices are handled correctly.
        if col + 1 < len(matrix[row]):
            heappush(heap, (matrix[row][col + 1], row, col + 1))

    return heappop(heap)[0]


# Binary Search
def kthSmallestBinarySearch(matrix: List[List[int]], k: int) -> int:
    """Return the ``k``-th smallest element (1-indexed) via value binary search.

    Searches the value range ``[matrix[0][0], matrix[-1][-1]]`` for the
    smallest value ``v`` such that at least ``k`` elements are ``<= v``.
    The matrix must be rectangular with rows and columns both sorted in
    ascending order; it does not have to be square.

    Args:
        matrix: Non-empty rectangular matrix, sorted along rows and columns.
        k: 1-based rank of the element to return.

    Returns:
        The ``k``-th smallest element of the matrix.

    Raises:
        IndexError: If ``matrix`` is empty.
    """
    # Track rows and columns separately so non-square matrices work.
    rows, cols = len(matrix), len(matrix[0])

    def check(mid):
        """Return True if at least ``k`` elements are ``<= mid``."""
        # Staircase walk starting at the bottom-left corner.
        i, j = rows - 1, 0
        num = 0

        while i >= 0 and j < cols:
            if matrix[i][j] <= mid:
                # Column j is sorted, so rows 0..i of it are all <= mid.
                num += i + 1
                j += 1
            else:
                i -= 1

        return num >= k

    left, right = matrix[0][0], matrix[-1][-1]

    # Lower-bound search: the smallest value passing check() is an element
    # of the matrix and is the answer.
    while left < right:
        mid = (left + right) // 2
        if check(mid):
            right = mid
        else:
            left = mid + 1

    return left


# Sample input: the 8th smallest element of this matrix is 13.
matrix = [
    [1, 5, 9],
    [10, 11, 13],
    [12, 13, 15],
]
k = 8

# Run both implementations on the same input; each should print 13.
for solver in (kthSmallestHeap, kthSmallestBinarySearch):
    print(solver(matrix, k))  # 13

Comments