Crushing Tech Interviews with the Two Heaps Pattern

If we have a problem where we're interested in knowing the smallest element in one part of a given set of elements and the biggest element in the other part, we can use the two heaps pattern.

Let's work through this problem together:

Implement a class that can calculate the median of a number stream. The class should have two methods: add_num (adds a number to the class) and find_median (finds the median of the stored numbers).

Note: If the total count of numbers inserted into the class is even, take the average of the two middle numbers.

Example:

1. add_num(5)
2. add_num(1)
3. find_median() -> output: (5+1)/2 = 3
4. add_num(8)
5. find_median() -> output: 5
6. add_num(6)
7. find_median() -> output: (6+5)/2 = 5.5

Can you figure out a brute force solution?

We can solve this inefficiently by maintaining a sorted list of values and returning the median whenever we need it. Unfortunately, inserting a number into a sorted list takes O(N) time, and we can do better!
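Here is a minimal sketch of that brute-force approach, assuming Python's standard bisect module for the sorted insert. The class name SortedListMedian is just for illustration; it exposes the same add_num and find_median methods as the problem statement.

```python
import bisect


class SortedListMedian:
    """Brute-force median: keep every number in one sorted Python list."""

    def __init__(self) -> None:
        self.values = []

    def add_num(self, num: int) -> None:
        # Binary search finds the slot in O(log N), but shifting the
        # elements after it to make room costs O(N).
        bisect.insort(self.values, num)

    def find_median(self) -> float:
        n = len(self.values)
        if n == 0:
            raise ValueError("no numbers added yet")
        mid = n // 2
        if n % 2 == 1:
            # Odd count: the middle element is the median.
            return float(self.values[mid])
        # Even count: average the two middle elements.
        return (self.values[mid - 1] + self.values[mid]) / 2.0
```

Feeding it the example (5, 1, then 8, then 6) gives medians of 3.0, 5.0 and 5.5. The O(N) shift inside insort is the cost the heap-based version avoids.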

Notice that we don't need a fully ordered list; we only need to identify the middle element(s).

For the middle element(s), half of the list is smaller than or equal to the middle and the other half is greater than or equal to it.

So why don't we keep two lists? One holds the smaller elements (let's call it small_elements) and one holds the larger elements (let's call it large_elements). With the elements split this way, the median is either the largest element in small_elements or the smallest element in large_elements, or, if the total number of elements is even, the average of those two numbers.
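To make the idea concrete, here is a small sketch (the function name median_from_halves is hypothetical) that splits an already-sorted list into the two halves and reads the median off their boundary. It gives small_elements the extra element when the count is odd, which matches the convention used later in the walkthrough. It still sorts up front and scans with max/min, so it only shows where the median lives, not how to maintain the halves efficiently.

```python
def median_from_halves(sorted_values: list[int]) -> float:
    """Median read off the boundary between two halves of a sorted list."""
    if not sorted_values:
        raise ValueError("empty list")

    # small_elements takes the extra element when the length is odd.
    split = (len(sorted_values) + 1) // 2
    small_elements = sorted_values[:split]
    large_elements = sorted_values[split:]

    if len(small_elements) == len(large_elements):
        # Even total: average the largest small and the smallest large element.
        return (max(small_elements) + min(large_elements)) / 2.0
    # Odd total: the largest element of small_elements is the median.
    return float(max(small_elements))
```

For example, median_from_halves([1, 5, 6, 8]) splits into [1, 5] and [6, 8] and returns 5.5.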

So how can we keep track of the smallest/largest element in a list? Using a heap, of course!

  • small_elements will be stored as a max heap.
  • large_elements will be stored as a min heap.

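But why does this make our solution better? Inserting into a heap (and rebalancing by moving one element between heaps) is an O(log N) operation, rather than the O(N) insertion into a sorted list we were doing before, and reading the top of either heap is O(1), so find_median stays cheap.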

Let's look at a diagram and see how this works (given the example above):

[Diagram: the max heap and min heap after each step of the example]

  1. add_num(5): the max heap is empty, so 5 goes into it. In general, we insert into the max heap if it is empty or its top element (i.e. the greatest element) is greater than or equal to the number we're inserting; otherwise the number goes into the min heap. After each insertion, we rebalance the heaps so that they hold the same number of elements. When the total is odd, let's keep the extra element in the max heap rather than the min heap. (You can go the other way, but the implementation will differ slightly.)
  2. Since 1 is smaller than 5, we add it to the max heap. The heaps are now unbalanced: there are two elements in the max heap and none in the min heap. To rebalance, we move the max heap's top element, 5, to the min heap.
  3. We find the median. The tops of the two heaps are 1 and 5, and since we have an even number of elements, the median is (1 + 5) / 2 = 3.
  4. 8 is larger than 1, the top element of the max heap, so we add it to the min heap. The min heap now holds more elements than the max heap, so we balance the other way: we pop the smallest number (5) from the min heap and push it onto the max heap.
  5. We find the median again. The total is now odd, so the median is simply the top element of the max heap: 5.
  6. Since 6 is greater than 5, the top of the max heap, we insert it into the min heap. Both heaps now hold two elements, so there's no need to rebalance.
  7. Find the median: (5 + 6) / 2 = 5.5.
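
```python
from heapq import heappush, heappop


class MedianStream:
    def __init__(self) -> None:
        # Max heap for the smaller half. heappush/heappop maintain a
        # min heap, so values are stored negated.
        self.max_heap = []
        # Min heap for the larger half.
        self.min_heap = []

    def find_median(self) -> float:
        if not self.max_heap:
            raise ValueError("find_median() called on an empty stream")

        # Even count: average the two middle values (the tops of both heaps).
        if len(self.max_heap) == len(self.min_heap):
            return -self.max_heap[0] / 2.0 + self.min_heap[0] / 2.0

        # Odd count: the max heap holds the extra element, so its top is the median.
        return float(-self.max_heap[0])

    def add_num(self, num: int) -> None:
        # Numbers no larger than the smaller half's maximum go into the max heap.
        if not self.max_heap or -self.max_heap[0] >= num:
            heappush(self.max_heap, -num)
        else:
            heappush(self.min_heap, num)

        self._rebalance()

    def _rebalance(self) -> None:
        # Keep len(max_heap) equal to len(min_heap), or one larger.
        if len(self.max_heap) > len(self.min_heap) + 1:
            heappush(self.min_heap, -heappop(self.max_heap))
        elif len(self.max_heap) < len(self.min_heap):
            heappush(self.max_heap, -heappop(self.min_heap))
```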
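To connect the class back to the walkthrough, here is a self-contained sketch that replays the same insert-then-rebalance rules as MedianStream.add_num and MedianStream._rebalance and records both heaps after each add. The function name trace_two_heaps is hypothetical, and it computes a median after every add (including the first), not only at the find_median steps of the example.

```python
from heapq import heappush, heappop


def trace_two_heaps(nums: list[int]) -> list[tuple[list[int], list[int], float]]:
    """After each add, return (smaller half, larger half, median), halves sorted."""
    max_heap: list[int] = []  # negated values, so the largest sits at index 0
    min_heap: list[int] = []
    states = []
    for num in nums:
        # Same insertion rule as MedianStream.add_num.
        if not max_heap or -max_heap[0] >= num:
            heappush(max_heap, -num)
        else:
            heappush(min_heap, num)

        # Same rebalancing rule as MedianStream._rebalance.
        if len(max_heap) > len(min_heap) + 1:
            heappush(min_heap, -heappop(max_heap))
        elif len(max_heap) < len(min_heap):
            heappush(max_heap, -heappop(min_heap))

        if len(max_heap) == len(min_heap):
            median = (-max_heap[0] + min_heap[0]) / 2.0
        else:
            median = float(-max_heap[0])
        states.append((sorted(-x for x in max_heap), sorted(min_heap), median))
    return states


for small, large, median in trace_two_heaps([5, 1, 8, 6]):
    print(small, large, median)
# Prints:
# [5] [] 5.0
# [1] [5] 3.0
# [1, 5] [8] 5.0
# [1, 5] [6, 8] 5.5
```

The second through fourth lines match steps 2 through 6 of the walkthrough: 5 moves to the min heap, then back to the max heap after 8 arrives, and 6 lands in the min heap without a rebalance.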

You may be wondering why we take the negative here: -self.max_heap[0] and here: heappush(self.max_heap, -num). It's because heapq's heappush and heappop only maintain a min heap, so by storing negated values we imitate a max heap: the smallest negated value is the negation of the largest original value.
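A quick illustration of the trick, using only heapq.heappush and heapq.heappop on the numbers from the example:

```python
import heapq

values = [5, 1, 8, 6]

# heappush/heappop maintain a min heap: the smallest value sits at index 0.
min_heap = []
for v in values:
    heapq.heappush(min_heap, v)

# Pushing negated values makes index 0 hold the negation of the largest value.
max_heap = []
for v in values:
    heapq.heappush(max_heap, -v)

smallest = min_heap[0]
largest = -max_heap[0]
# Popping and negating again yields the values in descending order.
popped_desc = [-heapq.heappop(max_heap) for _ in range(len(max_heap))]
```

Here smallest is 1, largest is 8, and popped_desc is [8, 6, 5, 1]. The only cost of the trick is remembering to negate on every push, peek, and pop.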

Here's a practice problem to help you get your head around it:

https://leetcode.com/problems/sliding-window-median/ (Hard)

Thanks!