Cooperative batching

2021-10-06 · 10 min read

Cooperative batching #

A problem I often find myself having in ML (and other domains) is collecting data from worker threads and doing something with it. Sometimes, all you want is one-way communication... but sometimes you want a response. And sometimes (in the case of ML) that data will get massaged and sent somewhere else for processing, and return back at some unspecified time in the future.

Problem description #

If we end at Process, this is a bog-standard Multi-Producer Single-Consumer scenario. However, when we want to return the results from Process to our Producers, things get a bit hairy. You need to track the identity of data, figure out where to send it - or deal with data from each producer in turn, ignoring gains from batching. This is often suboptimal, and can lead to resource contention and starvation.
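For reference, the conventional approach is a dedicated worker thread that drains a queue, with every producer passing along its own reply channel so results can be routed back by hand. A minimal sketch (process here is a stand-in for the expensive batched call):

import queue
import threading


def process(items):
    return [item for item in items]  # stand-in for the expensive batched call

requests = queue.Queue()


def worker_loop():
    while True:
        batch = [requests.get()]
        while not requests.empty():  # drain whatever has piled up
            batch.append(requests.get_nowait())
        results = process([data for data, _ in batch])
        for (_, reply), result in zip(batch, results):
            reply.put(result)  # route each result back by hand


threading.Thread(target=worker_loop, daemon=True).start()


def produce(data):
    reply = queue.Queue(maxsize=1)
    requests.put((data, reply))
    return reply.get()  # block until "my" result arrives

It works, but it needs the extra thread, and the routing bookkeeping grows with every variation of the workflow.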

Solution #

Instead we'd like to find a workflow where we can easily respond to each producer, with little extra boilerplate at callsites, minimal overhead, and no risk of data races or contention. I've found a pattern (if one can call it that) that solves this, with the following important qualities:

  • Cannot deadlock¹
  • Can be implemented with only minimal changes to caller code
  • Does as much work as possible in each go
  • Does not require a worker thread

I have no idea what this algorithm could be called, and I've never seen it described in this manner. It borrows from lock-free/CAS-based algorithms, but I've never seen this specific use-case and implementation described. If you have a name for this algorithm, do tell me!

So, moving on to implementation, we're going to need three classes, which we'll define in turn:

class BatchHandle:
    """A handle for a participant in a batch."""
    ...


class Batch:
    """A group of data to be processed together."""
    ...


class CooperativeQueue:
    """A queue-system with cooperative processing of the data, and response synchronization."""
    ...

Implementing CooperativeQueue #

We'll start at the external API: The CooperativeQueue. This isn't technically a queue, but for discoverability and recognition it makes the most sense. We'll go function-by-function, starting from __init__:

from threading import Event, Lock
from typing import Any, Callable, List, Optional


class CooperativeQueue:
    """A queue-system with cooperative processing of the data, and
    response synchronization."""

    def __init__(self,
                 callback: Callable[[List[Any]], List[Any]],
                 execution_lock: Optional[Lock] = None):
        """Create a new queue.

        If provided, the `execution_lock` allows synchronization
        with external resources when running `callback`.
        """
        if execution_lock is None:
            execution_lock = Lock()

        self._execution_lock = execution_lock
        """Lock for execution."""

        self._batch_lock = Lock()
        """Lock for the data-buffer."""

        self._batch = Batch()
        """The current batch of data being *queued* for execution."""

        self._callback = callback
        """The function to invoke for the data of each batch."""

Nothing too interesting here: we've got a lock for the "worker" process, a lock for the data we're gathering up, and the actual data. The execution_lock parameter deserves a quick note, though -- see the sketch below.
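If the callback touches a resource that other code also guards -- a model, a file, a device -- you can hand that same lock in. A minimal sketch, where model_lock and run_model are hypothetical stand-ins:

model_lock = Lock()  # hypothetical: also taken by other code using the same resource


def run_model(batch):
    return [x * 2 for x in batch]  # stand-in for the real batched call


q = CooperativeQueue(run_model, execution_lock=model_lock)

With that out of the way, let's define our put and get methods.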

    def put(self, data: Any) -> 'BatchHandle':
        """Participate in the next batch processed by the queue.

        Will return a `BatchHandle` which can be traded for the result
        of execution.
        """
        # Here we take the lock, and receive a "ticket" back for the result.
        with self._batch_lock:
            handle = self._batch.push(data)

        # If you're "first in line" you also become the leader, which means
        # you'll execute for yourself and everyone following you.
        if handle.leader:
            # Wait until the execution resource becomes available
            with self._execution_lock:
                # "Atomically" swap the current data batch
                with self._batch_lock:
                    my_batch, self._batch = self._batch, Batch()
                    # When we leave this scope, a new line will begin
                    # forming while we execute the callback.

                # We run the callback, still holding the execution lock.
                results = self._callback(my_batch.data())

                # We report the results and mark the batch as finished.
                my_batch.finish(results)

        return handle

    def get(self, handle: 'BatchHandle') -> Any:
        """Retrieve the result for participation in `handle`'s batch.

        Will block until the result becomes available.
        """
        return handle.retrieve()

I like to combine these two functions into a process function, as I most often care about the result immediately -- but it'd be equally possible to store a handle and read it at a later date.

    def process(self, data: Any) -> Any:
        """Participate in the next batch processed by the queue.

        Will return the equivalent of `callback([data])[0]`.
        """
        return self.put(data).retrieve()
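As a quick illustration of the handle route, once all three classes are in place (the doubling callback is just a toy):

q = CooperativeQueue(lambda batch: [x * 2 for x in batch])

handle = q.put(21)    # join the next batch; as the only caller we're the leader, so the callback runs here
# ... do other work before the answer is actually needed ...
print(q.get(handle))  # prints 42; no blocking, since the batch already finished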

Since there are multiple locks, I like to study this as a flow graph. The locked scopes become regions in the picture below, and the other code becomes nodes.

The leader in each batch will hit the top path, while every follower will hit the bottom. Tracing out the follower path first, it's easy to see that there's no risk of deadlock here -- assuming the trigger eventually fires, the thread will become unblocked.

The leader path is slightly more interesting. Once it has joined the batch as element 0, it'll immediately queue for the execution lock, waiting for any other execution to finish. Once it acquires the execution lock, it'll do a guarded swap of the batch. By taking the batch lock at this point, there's no race with other potential batch participants -- they'll either have joined already, or have to wait for the next batch. Then we can execute the callback while the next batch is being formed by late-comers.
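In outline form, put flows roughly like this:

put(data)
 ├─ push data under the batch lock, receive a handle
 ├─ follower (index > 0): return the handle immediately
 │    └─ retrieve(): wait on the batch's signal, then read results[index]
 └─ leader (index 0):
      ├─ queue for the execution lock (wait for any in-flight batch)
      ├─ swap in a fresh Batch under the batch lock
      ├─ run the callback on the swapped-out batch's data
      ├─ finish(): store the results and set the signal
      └─ release the execution lock and return the handle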

Once the callback ends we finish the batch: this both sets the results to be read and triggers the event, unblocking all other batch members. It'd be perfectly fine to lazily evaluate the handle inside get instead and trigger on first retrieval -- but it makes the code more complex for very little gain, since I avoid queuing computations whose results I don't expect to use anyway.

Implementing Batch #

Moving on, let's implement the batch. We've used three interface functions above: push, data, and finish. Since we've defined the data lock externally², these can be very simple:

class Batch:
    """A group of data to be processed together."""

    def __init__(self):
        """Create a new empty Batch."""
        self._data = []
        """All enqueued data in this batch."""

        self._results = []
        """The results, when finished."""

        self._signal = Event()
        """Synchronization signal for results."""

    def data(self):
        """Retrieve the data."""
        return self._data

    def push(self, data) -> 'BatchHandle':
        """Push new data in exchange for a `BatchHandle`."""
        index = len(self._data)
        self._data.append(data)
        return BatchHandle(self, index, index == 0)

    def finish(self, results):
        """Finish the batch, signaling all participants to continue."""
        self._results = results
        self._signal.set()

The only important part above is the order of setting the results and triggering the signal: the results must be in place before the signal is set, so a woken waiter never reads an empty result list.

Implementing BatchHandle #

Time to implement the last piece of the puzzle. The only public function on the handle is retrieve, apart from the constructor.

class BatchHandle:
    """A handle for a participant in a batch."""

    def __init__(self, batch: Batch, index: int, leader: bool):
        """Create a new BatchHandle, able to wait for batch completion."""
        self._index = index
        self._batch = batch
        self.leader = leader

    def retrieve(self):
        """Wait for the associated batch to complete."""
        self._batch._signal.wait()
        return self._batch._results[self._index]

Again, a very straightforward implementation -- note here the relation of retrieve to finish. All threads will hit the signal.wait() before accessing the batch, so finish has to perform the last mutations ever applied to the Batch. Looking at all three pieces together, we can also see that my_batch is the only reference to the Batch outside of the BatchHandles, which helps guarantee the invariants.

Usage #

So, finally: let's talk use-cases. There are of course many different ways to do this, but I promised minimal changes to producer/client code. Let's start by defining our producer:

import random
import time

COUNT = 5


def producer(worker, idx):
    for _ in range(COUNT):
        time.sleep(random.random() / 2.0)
        idx_p1 = worker.add_1(idx)  # wait, what's this?
        assert idx_p1 == idx + 1, "incorrect result back from worker!"

We need to solve one mystery before we can run this, though! There's a magic function here that isn't defined: add_1. Let's define a class that provides it:

class Worker(CooperativeQueue):
    def __init__(self):
        super().__init__(self._do_work)

    def _work(self, data):
        time.sleep(random.random())
        return [d + 1 for d in data]

    def _do_work(self, data):
        print(f'Handling batch of {len(data)} items')
        return self._work(data)

    def add_1(self, value):
        return self.process(value)
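Even without any threads in play a single call works -- it simply forms a batch of one:

worker = Worker()
print(worker.add_1(5))  # prints "Handling batch of 1 items", then 6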

And now we can actually run this:

from concurrent.futures import ThreadPoolExecutor

worker = Worker()
with ThreadPoolExecutor(20) as tpx:
    tpx.map(producer, [worker] * 20, range(20))
Handling batch of 19 items
Handling batch of 1 items
Handling batch of 19 items
Handling batch of 1 items
Handling batch of 19 items
Handling batch of 1 items
Handling batch of 19 items
Handling batch of 1 items
Handling batch of 17 items
Handling batch of 2 items

We get quite uneven batching, but it's a very simple workload. In my experience, when (a) producer cost is stable and (b) worker cost scales with the number of elements, there's a fairly nice spread of data. The image below is 20 * 20,000 scalars, and there's a nice distribution peaking in the 7-8-9 area. If your worker is significantly slower, batch sizes will tend to drift towards the edges, which oftentimes isn't bad -- just worth keeping in mind.
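If you want to measure that spread yourself, one option (the CountingWorker below is a hypothetical variant, not used elsewhere in this post) is to record batch sizes instead of printing them:

from collections import Counter


class CountingWorker(Worker):
    def __init__(self):
        super().__init__()
        self.batch_sizes = Counter()  # batch size -> number of occurrences

    def _do_work(self, data):
        self.batch_sizes[len(data)] += 1
        return self._work(data)

After a run, worker.batch_sizes holds the histogram.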

Now, I don't generally have queues of scalars flying all over the place. The data is usually more complex, and different producers might have different amounts of it. So we need the merge and unmerge steps from the first graph. Let's start by changing our producer:

def producer(worker, idx):
    for _ in range(COUNT):
        time.sleep(random.random() / 2.0)
        idx_p1 = worker.add_1([idx for _ in range(idx)])
        assert idx_p1 == [idx + 1] * idx, "incorrect result back from worker!"

This'll crash our regular code, so we need to implement our fan-in/fan-out operators! I call these mergers and unmergers, and implement them like so:

def _make_unmerger(unmerge_data):
    def _unmerger(data):
        output = []
        for start, length in unmerge_data:
            output.append(data[start:start + length])
        return output

    return _unmerger


def _merge(data):
    merged_data = []
    unmerge_data = []
    running_count = 0
    for d in data:
        merged_data.extend(d)
        unmerge_data.append((running_count, len(d)))
        running_count += len(d)

    return merged_data, _make_unmerger(unmerge_data)
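As a quick sanity check of the pair:

merged, unmerge = _merge([[1, 2], [3], [4, 5, 6]])
print(merged)                            # [1, 2, 3, 4, 5, 6]
print(unmerge([d + 1 for d in merged]))  # [[2, 3], [4], [5, 6, 7]]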

Finally, we'll define a new worker using these merge operations -- it's a single override compared to the class above:

class FlatteningWorker(Worker):
    def _do_work(self, data):
        data, unmerge = _merge(data)
        print(f'Handling batch of {len(data)} items')
        return unmerge(self._work(data))


worker = FlatteningWorker()
with ThreadPoolExecutor(20) as tpx:
    tpx.map(producer, [worker] * 20, range(20))
Handling batch of 141 items
Handling batch of 49 items
Handling batch of 141 items
Handling batch of 49 items
Handling batch of 93 items
Handling batch of 97 items
Handling batch of 93 items
Handling batch of 97 items

Attachments #

A single file containing all code from this blog post.


  1. There's an escape hatch, via the execution lock: if you call process while already holding that lock, you'll deadlock. ↩︎

  2. I'll write another blog post about this some time, but in general I put locks in consumer code unless it's absolutely required for correctness, like in CooperativeQueue. ↩︎