Optimizing a Rust Thread-Pool (Part 1/?)


NOTE: All the benchmark and performance considerations in this post are extremely specific to the presented scenario and the corresponding hardware.


Let's build a thread-pool!

Motivation

It began because I wanted to learn more about parallelism/concurrency in Rust. It then continued because I wanted to see how far I could push it, hence the open-ended numbering in the title.

What's a Thread-Pool?

When writing concurrent code, you typically want to execute multiple things in parallel and/or concurrently. A naive approach is to spawn a new thread for each task that should run in parallel or concurrently. This can be a perfectly fine solution for a small number of tasks, but it becomes a problem as the number of tasks grows.

To understand why, let's consider a simple example:

use std::thread;
use std::vec::Vec;

const LARGE_AMOUNT: usize = 100_000_000;

// some simple task
fn do_task(_i: usize) { /* ... */ }

fn main() {
    let mut threads = Vec::with_capacity(LARGE_AMOUNT);

    for i in 0..LARGE_AMOUNT {
        threads.push(thread::spawn(move || do_task(i)));
    }

    // join all threads
    for t in threads {
        t.join().unwrap();
    }
    println!("done!");
}

We have a LARGE_AMOUNT of tasks (100 million). If we estimate that launching a thread takes roughly 10 microseconds (µs), what total do we end up with?

1,000 seconds (about 17 minutes).

And that is just for launching the threads. It does not include the actual do_task(i) or the extraordinary amount of context-switching required.
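The 10 µs figure is only a ballpark estimate. To get a number for your own machine, you can time a batch of spawn-and-join round trips and extrapolate. The sketch below is one rough way to do that. The function name average_spawn_join and the batch size are arbitrary choices. The sketch spawns threads one at a time and joins each immediately, so it measures the creation plus teardown of an idle thread, not the cost under contention.

```rust
use std::thread;
use std::time::{Duration, Instant};

/// Spawns `n` empty threads one after another, joining each immediately,
/// and returns the average wall-clock time per spawn + join.
fn average_spawn_join(n: u32) -> Duration {
    assert!(n > 0, "need at least one sample");
    let start = Instant::now();
    for _ in 0..n {
        thread::spawn(|| {}).join().unwrap();
    }
    start.elapsed() / n
}

fn main() {
    let avg = average_spawn_join(10_000);
    // extrapolate to the LARGE_AMOUNT = 100_000_000 tasks of the example
    let estimate_secs = avg.as_secs_f64() * 100_000_000.0;
    println!("average spawn+join: {avg:?}, extrapolated: {estimate_secs:.0} s");
}
```

The result depends heavily on the OS and hardware, so treat it as an order-of-magnitude check only.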

The solution is simple: don't spawn a new thread for each task. Instead, distribute the tasks across a fixed set of threads (one might call that a 'pool' of threads).

First Attempt

So what are our requirements for a thread-pool?

  • We want to prespawn a number of threads that will run in parallel and execute our tasks.
  • We want to be able to submit tasks to the pool.
  • We want to be able to wait for the tasks to finish. In our examples this happens only once; for real-world usage we would want something like promise/future-based result handling instead of just joining.

So we first need a type that represents a task. Let's call it Task (how original).

All we really need Task to hold is something callable. In our example we won't support passing values in after dispatch; essentially, all inputs are fixed once the task is created. That means we can use a boxed FnOnce closure to represent our task.

The advantage is that closures implement FnOnce automatically, which keeps our implementation very simple. It also guarantees that any arguments are captured by, and bound to, the task itself.

With this we can write our task type:

struct Task {
    pub func: Box<dyn FnOnce() -> () + Send>,
}

We already covered FnOnce. We also specify that the function returns nothing (the () unit type). Additionally, we require the closure to implement the Send trait, which guarantees that func can be safely moved to another thread. It is handed to one worker; it is not shared between several threads at once.

Every closure has its own anonymous type, so we can't name a single concrete type for func. Instead we use dyn to create a trait object, which we need to wrap in a Box<...> because the size of a trait object isn't known at compile time. (I'm not sure my explanation is entirely accurate, but it works well enough for me for now. Feel free to correct me if I'm majorly off.)
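To see why the Box<dyn ...> is needed, here is a small self-contained sketch. It stores two closures with different anonymous types in the same Vec<Task>: one captures a String, the other an integer. The whole Vec is then moved to another thread, which is allowed because each boxed closure is Send, and each task is called exactly once. The demo function and the shared log exist only for illustration.

```rust
use std::sync::{Arc, Mutex};
use std::thread;

// Same shape as the post's Task: a boxed, sendable, run-once closure.
struct Task {
    func: Box<dyn FnOnce() + Send>,
}

/// Builds tasks from closures of different concrete types, moves them to a
/// worker thread, runs them there, and returns what they logged.
fn demo() -> Vec<String> {
    let log = Arc::new(Mutex::new(Vec::new()));
    let greeting = String::from("hello"); // owned data, moved into a closure
    let n = 21;

    let log1 = Arc::clone(&log);
    let log2 = Arc::clone(&log);
    let tasks = vec![
        Task { func: Box::new(move || log1.lock().unwrap().push(greeting)) },
        Task { func: Box::new(move || log2.lock().unwrap().push((n * 2).to_string())) },
    ];

    // Vec<Task> can cross the thread boundary because every boxed closure is Send.
    thread::spawn(move || {
        for task in tasks {
            (task.func)(); // calling a Box<dyn FnOnce> consumes it
        }
    })
    .join()
    .unwrap();

    let result = log.lock().unwrap().clone();
    result
}

fn main() {
    println!("{:?}", demo());
}
```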


With the base element Task now defined, we can go on and specify the thread-pool itself.

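use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::thread::{self, JoinHandle};

pub struct Pool {
    threads: Vec<JoinHandle<()>>,
    queue: Arc<Mutex<Vec<Task>>>,
    done: Arc<AtomicBool>,
}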

The JoinHandle type is what thread::spawn returns, so we store these handles to join the threads later on. The queue is a simple Vec of Tasks. It stores the tasks we want to execute and hands them out to the threads, and it needs to be Arc<Mutex<...>> because it is shared between threads. Note that Vec::pop takes from the end, so tasks are picked up in LIFO order.

Finally, we don't want the threads to shut down just because the queue is momentarily empty. So we keep a done flag, which we set to true when we actually want the threads to exit. (It also needs Arc<...> for sharing between threads.)

Creating a new thread-pool is pretty straightforward. The only interesting part is the initialization of the threads themselves:

impl Pool {
    pub fn new(amount: usize) -> Self {
        let mut threads = Vec::with_capacity(amount); // pre-allocate room for the join-handles
        let queue = Arc::new(Mutex::new(Vec::with_capacity(amount))); 
        let done = Arc::new(AtomicBool::new(false));

        for _ in 0..amount {
            // clone the queue and done flag for each thread
            let queue = queue.clone(); 
            let done = done.clone(); 

            // Spawn the thread with a worker-function and store the join-handle to the vector.
            threads.push(thread::spawn(move || {
                // used for exponential backoff when waiting for a task
                let mut sleep_counter = 1; 

                loop { // loop forever
                    // check if the done flag is set
                    if done.load(Ordering::Relaxed) {
                        // if so, break out of the loop and terminate thread
                        break;
                    }

                    // try and pop a task from the queue
                    let task = {
                        let mut queue = queue.lock().unwrap();
                        queue.pop()
                    };

                    // if we got a task, execute it
                    if let Some(task) = task {
                        (task.func)();
                    } else {
                        // no task: sleep a bit with exponential backoff (so we don't block the mutex too much).
                        // Note: sleep_counter is never reset, so waits keep growing over the worker's lifetime.
                        thread::sleep(std::time::Duration::from_nanos(1 << sleep_counter));
                        sleep_counter += 1;
                    }
                    // continue looping
                }
            }));

        }

        // return the pool
        Self {
            threads,
            queue,
            done,
        }
    }
}

This is a large code-block, but it's also almost the entirety of the complicated part of the implementation.

Breaking it down:

  • We setup our state variables: threads, queue, done
  • We create a configurable amount of threads, which we all spawn via thread::spawn and store the JoinHandle
    • These threads all run the same 'worker-function' which basically just loops forever and waits for a task to be available (or until done is set)
    • If a task is available, it is removed from the queue and executed
    • If no task is available, we sleep for a bit and try again, with exponential backoff so as not to block the queue mutex too much.

(Note: While writing this, I noticed that it might have been worth trying other waiting strategies. Since we'll get rid of this mechanism soon anyway, I chose to keep it simple.)
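One simple alternative waiting strategy keeps the exponential backoff but caps the exponent and resets it after every successful pop. This is sketched here as an assumption; it is not what the benchmarks used. In the worker above, the sleep counter only ever grows, so a worker that was idle for a while keeps sleeping longer and longer even after work arrives. The Backoff type and its max_exp parameter are hypothetical names.

```rust
use std::thread;
use std::time::Duration;

/// Exponential backoff with a cap and a reset (hypothetical helper).
struct Backoff {
    exp: u32,
    max_exp: u32,
}

impl Backoff {
    fn new(max_exp: u32) -> Self {
        // 1u64 << 64 would overflow, so keep the exponent below 64
        assert!((1..64).contains(&max_exp));
        Backoff { exp: 1, max_exp }
    }

    /// Delay to sleep now; later delays double until they reach 2^max_exp ns.
    fn next_delay(&mut self) -> Duration {
        let delay = Duration::from_nanos(1u64 << self.exp);
        self.exp = (self.exp + 1).min(self.max_exp);
        delay
    }

    /// Call after a successful pop, so the next idle phase starts short again.
    fn reset(&mut self) {
        self.exp = 1;
    }
}

fn main() {
    let mut backoff = Backoff::new(20);
    for _ in 0..5 {
        let d = backoff.next_delay();
        println!("sleeping {d:?}");
        thread::sleep(d);
    }
    backoff.reset();
}
```

In the worker-loop, one would call backoff.reset() after popping a task and thread::sleep(backoff.next_delay()) otherwise. With max_exp = 20, the longest single sleep is 2^20 ns, roughly 1 ms.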

For now, submitting a task to the pool just means pushing a Task onto the queue. One of the worker-threads will then pick it up and execute it.

impl Pool {
    pub fn submit<F: FnOnce() -> () + Send + 'static>(&self, func: F) {
        let mut queue = self.queue.lock().unwrap();
        queue.push(Task {
            func: Box::new(func),
        });
    }
}

This is pretty straightforward: we lock the queue and push the task. The func parameter has a 'static bound, so the closure can't borrow from the caller's stack. Allowing borrowed data would require something like scoped threads, but for our purposes this is fine.
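Sometimes tasks need to borrow data from the caller instead of requiring 'static. For that case, the standard library's std::thread::scope (stable since Rust 1.63) guarantees that every thread spawned inside the scope is joined before the scope returns, so borrows are allowed. This is a different model from our pool, because threads are spawned per call rather than reused. The parallel_sum function below is only a hypothetical illustration of the borrowing.

```rust
use std::thread;

/// Sums a borrowed slice on up to `workers` scoped threads. The closures
/// borrow `data` directly, which a 'static task type would reject.
fn parallel_sum(data: &[u64], workers: usize) -> u64 {
    // chunk size must be non-zero, even for empty input or zero workers
    let chunk = data.len().div_ceil(workers.max(1)).max(1);
    thread::scope(|s| {
        let handles: Vec<_> = data
            .chunks(chunk)
            .map(|part| s.spawn(move || part.iter().sum::<u64>()))
            .collect();
        // all scoped threads are joined here at the latest
        handles.into_iter().map(|h| h.join().unwrap()).sum()
    })
}

fn main() {
    let numbers: Vec<u64> = (1..=1_000).collect();
    // `numbers` is borrowed, not moved or cloned into the threads
    println!("sum = {}", parallel_sum(&numbers, 4));
    println!("still usable afterwards: {} items", numbers.len());
}
```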

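Joining is also pretty easy. We first wait until the queue is empty, meaning every task has been picked up by a worker. Then we set done to true and join all threads. Each worker only checks done between tasks, so joining also waits for any task still in progress.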

impl Pool {
    // waits for all tasks to finish, and then joins all threads
    pub fn join_all(self) {
        // wait for all tasks to finish
        let mut sleep_counter = 1;
        while !self.queue.lock().unwrap().is_empty() {
            thread::sleep(std::time::Duration::from_nanos(1 << sleep_counter));
            sleep_counter += 1;
        }

        self.done.store(true, Ordering::Relaxed);
        for thread in self.threads.into_iter() {
            thread.join().unwrap();
        }
    }
}

With this we can now create a simple example to measure the performance of our thread-pool (essentially what we had above):

let pool = v1_simple::Pool::new(pool_size);

for i in 0..TASK_COUNT {
    pool.submit(move || {
        collatz(i);
    });
}

pool.join_all();

Benchmarking

For the actual workload, I went with simply computing the Collatz sequence for a given number, specifically the number of steps required to reach 1. This is a little contrived, but it works fine as an example: it's compute-heavy, with non-trivial run-time characteristics (so not just a sleep(x)).

pub const PRINT_INTERVAL: u64 = 333333;

pub fn collatz(mut n: u64) {
    if n == 0 {
        return;
    }
    let original = n;
    let mut steps: u32 = 0;
    while n != 1 {
        if n % 2 == 0 {
            n /= 2;
        } else {
            n = 3 * n + 1;
        }
        steps += 1;
    }
    // print some subresults, so we can see the progress and the compiler doesn't optimize the computation away
    if original % PRINT_INTERVAL == 0 {
        println!("{original} took {steps} steps to converge");
    }
}

(We only print some of the values, to make sure we stay compute-bound & don’t choke on I/O)
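Printing every PRINT_INTERVAL-th result is one way to keep the compiler from optimizing the work away. Another option is to return the step count and pass values through std::hint::black_box (stable since Rust 1.66). This is sketched below as an alternative; it is not what the benchmarks used. Its documentation describes black_box as a best-effort hint, not a guarantee. collatz_steps is a hypothetical variant of the collatz function above.

```rust
use std::hint::black_box;

/// Same loop as `collatz` above, but returns the step count instead of
/// printing, so callers can check or accumulate it.
fn collatz_steps(mut n: u64) -> u32 {
    if n == 0 {
        return 0;
    }
    let mut steps = 0;
    while n != 1 {
        n = if n % 2 == 0 { n / 2 } else { 3 * n + 1 };
        steps += 1;
    }
    steps
}

fn main() {
    let mut total: u64 = 0;
    for i in 1..=1_000u64 {
        // black_box hides input and result from the optimizer (best effort),
        // so the loop is not constant-folded even though nothing is printed per item
        total += u64::from(black_box(collatz_steps(black_box(i))));
    }
    println!("total steps: {total}");
}
```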

With this we can now profile our implementation over varying pool-sizes (while keeping the workload constant).

I decided to compare my implementations against a C++ thread-pool implementation that I worked with previously, BS::thread_pool (https://github.com/bshoshany/thread-pool). It's a great header-only library, and I can only recommend it if you need to work with C++, but that's not what we are here for.

The C++ code for the workload looks essentially the same. On the C++ side, we execute it via BS::thread_pool's pool.submit_loop(…) feature.

First results?

So how does our naive mutex implementation compete with the C++ implementation?

Not too well... Not only are we 2x slower with a single worker-thread, we also don’t see any speedup with more threads. So what is going on?

Second attempt

When we profile our program, we see that most of the time is actually spent inside Pool::submit(…). That should immediately make us suspicious: we want the thread-pool to accelerate our workload, not slow it down. So, what can we do? For this we can take ‘inspiration’ (definitely not stealing) from BS::thread_pool, where we used pool.submit_loop(…). This differs from our implementation in two important ways:

  1. It submits the entire loop into the task-queue at once, not one item at a time. This requires a single mutex-lock instead of N.
  2. It chunks the loop into separate batches instead of N tasks (we’ll get to this later).

So starting with the first point, we can write an equivalent function in Rust, that also takes an iterator (alongside func):

// publishes a new task per item in the iterator
pub fn submit_iter<F: Fn(T) + Send + Sync + 'static, T: Send + 'static>(
    &self,
    func: Arc<F>,
    iter: impl IntoIterator<Item = T>,
) {
    let mut queue = self.queue.lock().unwrap();
    for i in iter {
        let func = Arc::clone(&func);
        queue.push(Task {
            func: Box::new(move || func(i)),
        });
    }
}

With this, we essentially do the same as in our first Pool::submit() implementation. The difference is that the for-loop over the given iterator now runs INSIDE the mutex-locked section, so we lock only once per call.

We again benchmark this:

Still quite disappointing, even worse at times. So on to the next point:

Third attempt

Right now each Task we submit is a single call to collatz(i). So we have N-tasks, each only a single item/call. We can do better, by chunking the total range of the passed in iter into batches.

Before we had tasks like: [1, 2, 3, 4, 5, 6, …]

But now we can reduce the total amount of tasks by batching these together: [[1, 2, 3], [4, 5, 6] , …].

With this, the worker-threads only need to take the queue lock once per task, and each task now covers multiple iterations. This is actually quite easy to do in our implementation: we bring in the itertools crate and use its .chunks() method from the Itertools extension trait:

// requires `use itertools::Itertools;` in scope for `.chunks()`
pub fn submit_iter<F: Fn(T) + Send + Sync + 'static, T: Send + 'static>(
    &self,
    func: Arc<F>,
    iter: impl IntoIterator<Item = T>,
    chunk_size: usize,
) {
    let mut queue = self.queue.lock().unwrap();
    for chunk in &iter.into_iter().chunks(chunk_size) { // chunk iter
        let func = Arc::clone(&func);
        let batch: Vec<T> = chunk.collect();
        queue.push(Task {
            func: Box::new(move || {
                for item in batch {
                    func(item);
                }
            }),
        });
    }
}

(We could’ve done the chunking ourselves, but this makes the code a lot cleaner)
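For integer workloads like ours, the chunking can also be done by hand without collecting each batch into a Vec<T>, because a batch is just a sub-range. The sketch below splits a Range<u64> into consecutive sub-ranges of at most chunk_size items. range_chunks is a hypothetical helper, not part of itertools. It assumes a 64-bit target, so that the chunk size fits into the usize expected by step_by.

```rust
use std::ops::Range;

/// Splits `range` into consecutive sub-ranges of at most `chunk_size` items
/// (hypothetical helper; the last chunk may be shorter).
fn range_chunks(range: Range<u64>, chunk_size: u64) -> impl Iterator<Item = Range<u64>> {
    assert!(chunk_size > 0, "chunk_size must be non-zero");
    let end = range.end;
    range
        .step_by(chunk_size as usize)
        // saturating_add avoids overflow for ranges ending near u64::MAX
        .map(move |start| start..start.saturating_add(chunk_size).min(end))
}

fn main() {
    for batch in range_chunks(0..10, 3) {
        // each batch would become one Task that loops over its items
        let items: Vec<u64> = batch.clone().collect();
        println!("{batch:?} -> {items:?}");
    }
}
```

A task built from such a batch would capture the Range<u64> and run `for i in batch { func(i) }`. Each task then carries two integers instead of a heap-allocated Vec.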

Finally we are getting somewhere!

Might be obvious in retrospect, but minimizing the amount of tasks reduces the overhead quite drastically.

So now what is left?

Looking at the code, there's still a rather ugly part in the worker-loop, where it fails to get a task:

// if we got a task, execute it
if let Some(task) = task { 
    // ...
} else {
    // no task, sleep a bit with exponential backoff (so we don't block the mutex too much)
    thread::sleep(std::time::Duration::from_nanos(1 << sleep_counter));
    sleep_counter += 1;
}

(The same mechanism is also present in Pool::join_all().) We are potentially sleeping quite a long time before trying to reacquire the lock. The original reason for this was to prevent the CPUs from just spamming the mutex, but it can make us sleep longer than required, wasting precious time.

But maybe we can do something smarter to get closer to that gray line?

Condvars

Condition-Variables (available in Rust as std::sync::Condvar) let us put a thread to sleep and wake it up later, without any active polling on that thread's part. Another thread notifies the condvar, and the sleeping thread is woken up. Wake-ups can also be spurious, so the waiting thread still re-checks its condition after waking.
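Here is a minimal sketch of the pattern, independent of the pool. One thread waits until a shared Vec is non-empty; the main thread pushes a value and calls notify_one. Condvar::wait_while (stable since Rust 1.42) releases the lock while sleeping and re-checks the predicate after every wake-up. It is equivalent to a hand-written `while cond { guard = cvar.wait(guard).unwrap(); }` loop. handoff is a hypothetical function name.

```rust
use std::sync::{Arc, Condvar, Mutex};
use std::thread;

/// A waiter thread sleeps until the shared queue is non-empty; the caller
/// pushes `value` and notifies. Returns what the waiter received.
fn handoff(value: u32) -> u32 {
    let shared = Arc::new((Mutex::new(Vec::<u32>::new()), Condvar::new()));

    let waiter = {
        let shared = Arc::clone(&shared);
        thread::spawn(move || {
            let (lock, cvar) = &*shared;
            // releases the lock while asleep; re-checks `is_empty` on every wake-up,
            // and returns immediately if the value was pushed before we got here
            let mut queue = cvar
                .wait_while(lock.lock().unwrap(), |q| q.is_empty())
                .unwrap();
            queue.pop().unwrap()
        })
    };

    let (lock, cvar) = &*shared;
    lock.lock().unwrap().push(value); // state changes under the lock
    cvar.notify_one();

    waiter.join().unwrap()
}

fn main() {
    println!("received {}", handoff(7));
}
```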

We can use this for multiple aspects of our implementation:

  1. (In the worker-thread) When waiting for new work to appear in the queue.
  2. (In the main-thread) When waiting for the worker-threads to be done with their work.

For the first part, we can simply add a cond-var that notifies the listeners that new work is available (instead of the thread::sleep(...)):

pub fn submit_iter<F: Fn(T) + Send + Sync + 'static, T: Send + 'static>(
    &self,
    func: Arc<F>,
    iter: impl IntoIterator<Item = T>,
    chunk_size: usize,
) {
    let mut queue = self.queue.lock().unwrap();
    for chunk in &iter.into_iter().chunks(chunk_size) {
      // same content
    }
    // Notify waiting threads that there is work ready
    self.work_condvar.notify_all();
}

And we then need to change our worker-loop to use this work_condvar instead of sleeping:

let work_condvar = Arc::new(Condvar::new());

// the worker-loop:
loop {
    let mut guard = queue.lock().unwrap();
  
    while guard.is_empty() && !done.load(Ordering::Acquire) {
        // queue is empty: wait() atomically releases the lock and puts this thread to sleep
        guard = work_condvar.wait(guard).unwrap();
        // after a notify (or a spurious wake-up), wait() has re-acquired the lock; the loop re-checks the condition
    }
    // rest of the loop is mostly the same
    // ..
}

This removes both the oversleeping and the repeated polling of the mutex.


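The second place where we can use condvars is Pool::join_all(). For this we introduce a second condvar that notifies the main-thread once all work is done. We also need to keep track of the number of tasks currently executing (the executing counter below), because an empty queue alone doesn't mean that all tasks have finished.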

pub fn join_all(self) {
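    // Signal shutdown. The flag is set while holding the queue lock, so a worker can't
    // check `done`, miss the notification below, and then wait forever.
    {
        let _guard = self.queue.lock().unwrap();
        self.done.store(true, Ordering::Release);
    }
    self.work_condvar.notify_all(); // wake any worker-threads waiting in their while-loop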

    // wait for all work to complete (queue empty + no executing tasks)
    let mut guard = self.queue.lock().unwrap();
  
    while !guard.is_empty() || self.executing.load(Ordering::Acquire) > 0 {
        // main thread keeps waiting for completion signal
        guard = self.completion_condvar.wait(guard).unwrap();
    }
    drop(guard);

    // finally, join the threads
    for thread in self.threads.into_iter() {
        thread.join().unwrap();
    }
}
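The post only shows fragments of the condvar version; the executing counter and the worker's bookkeeping are elided. The self-contained sketch below fills in those gaps under its own assumptions, so treat it as one possible implementation rather than the benchmarked code. All shared state (queue, executing count, done flag) lives under a single Mutex instead of separate atomics, so every condvar predicate is checked while holding the lock. The queue is a VecDeque (FIFO) instead of a Vec (LIFO). State, Shared, worker and sum_with_pool are hypothetical names.

```rust
use std::collections::VecDeque;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::thread::{self, JoinHandle};

type Job = Box<dyn FnOnce() + Send + 'static>;

// All mutable state sits behind one Mutex, so each condvar predicate
// (queue empty? tasks running? shutting down?) is checked under the lock.
struct State {
    queue: VecDeque<Job>,
    executing: usize,
    done: bool,
}
struct Shared {
    state: Mutex<State>,
    work_condvar: Condvar,       // signalled on new work or shutdown
    completion_condvar: Condvar, // signalled when queue is empty and nothing runs
}
pub struct Pool {
    threads: Vec<JoinHandle<()>>,
    shared: Arc<Shared>,
}

impl Pool {
    pub fn new(amount: usize) -> Self {
        assert!(amount > 0, "a pool without workers would never finish");
        let shared = Arc::new(Shared {
            state: Mutex::new(State { queue: VecDeque::new(), executing: 0, done: false }),
            work_condvar: Condvar::new(),
            completion_condvar: Condvar::new(),
        });
        let threads = (0..amount)
            .map(|_| {
                let shared = Arc::clone(&shared);
                thread::spawn(move || worker(&shared))
            })
            .collect();
        Self { threads, shared }
    }

    pub fn submit<F: FnOnce() + Send + 'static>(&self, func: F) {
        self.shared.state.lock().unwrap().queue.push_back(Box::new(func));
        self.shared.work_condvar.notify_one();
    }

    pub fn join_all(self) {
        {
            let mut state = self.shared.state.lock().unwrap();
            // sleep until the queue is drained and no task is mid-execution
            while !state.queue.is_empty() || state.executing > 0 {
                state = self.shared.completion_condvar.wait(state).unwrap();
            }
            state.done = true; // set under the lock, so no worker misses it
        }
        self.shared.work_condvar.notify_all();
        for handle in self.threads {
            handle.join().unwrap();
        }
    }
}

fn worker(shared: &Shared) {
    loop {
        let job = {
            let mut state = shared.state.lock().unwrap();
            while state.queue.is_empty() && !state.done {
                state = shared.work_condvar.wait(state).unwrap();
            }
            // done flag set and nothing left to run: exit the worker
            let Some(job) = state.queue.pop_front() else { return; };
            state.executing += 1;
            job
        }; // lock released here, before running the job
        job();
        let mut state = shared.state.lock().unwrap();
        state.executing -= 1;
        if state.queue.is_empty() && state.executing == 0 {
            shared.completion_condvar.notify_all();
        }
    }
}

/// Hypothetical usage: sum 1..=n by submitting one task per number.
fn sum_with_pool(n: u64, workers: usize) -> u64 {
    let total = Arc::new(AtomicU64::new(0));
    let pool = Pool::new(workers);
    for i in 1..=n {
        let total = Arc::clone(&total);
        pool.submit(move || { total.fetch_add(i, Ordering::Relaxed); });
    }
    pool.join_all();
    total.load(Ordering::Relaxed)
}

fn main() { println!("sum = {}", sum_with_pool(100, 4)); }
```

Keeping executing and done under the same mutex as the queue costs a little extra locking per task. In exchange, there is no need to reason about missed wake-ups between separate atomics and the condvar. One limitation: a task that panics leaves executing incremented, which would make join_all hang. Handling that is out of scope for this sketch.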

With this, we have no more explicit sleeping, and all the wake-up calls are handled by the OS. What does this do to our performance? Let’s see:

(For this we only plot our last two implementations, and the C++ reference)

Success! But we can do even better. In the future, I hope we'll get to explore some more optimizations, and maybe even play around with some lock-free data-structures. That's for another day, though; right now we have a nice baseline to work with.