Rust でスレッドプールを作る

RPL に書いてあるとおりの内容です。とりあえず完成品から。

use std::{
    sync::{
        mpsc::{self, Receiver},
        Arc, Mutex,
    },
    thread::{self, sleep},
};

pub struct ThreadPool {
    workers: Vec<Worker>,
    sender: Option<mpsc::Sender<Job>>,
}

type Job = Box<dyn FnOnce() + Send + 'static>;

impl ThreadPool {
    pub fn new(size: usize) -> ThreadPool {
        assert!(size > 0);

        let (sender, receiver) = mpsc::channel();

        let receiver = Arc::new(Mutex::new(receiver));

        let mut workers = Vec::with_capacity(size);
        for id in 0..size {
            workers.push(Worker::new(id, Arc::clone(&receiver)));
        }

        ThreadPool {
            workers,
            sender: Some(sender),
        }
    }

    pub fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let job = Box::new(f);
        self.sender.as_ref().unwrap().send(job).unwrap();
    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        drop(self.sender.take());
        for worker in self.workers.drain(..) {
            println!("Shutting down worker {}", worker.id);

            worker.thread.join().unwrap();
        }
    }
}

struct Worker {
    id: usize,
    thread: thread::JoinHandle<()>,
}

impl Worker {
    fn new(id: usize, receiver: Arc<Mutex<Receiver<Job>>>) -> Worker {
        let thread = thread::spawn(move || {
            loop {
                let message = receiver.lock().unwrap().recv();

                match message {
                    Ok(job) => {
                        println!("Worker {id} got a job; executing.");
                        job();
                    }

                    Err(_) => {
                        println!("Worker {id} disconnected; shutting down.");
                        break;
                    }
                }
            }
        });

        Worker { id, thread }
    }
}

ThreadPool 初期化

ThreadPool の初期化から見ていきましょう。mpsc::channel() はとくに何の変哲もない SenderReceiver を返すだけのやつです。

    pub fn new(size: usize) -> ThreadPool {
        assert!(size > 0);

        let (sender, receiver) = mpsc::channel();

        let receiver = Arc::new(Mutex::new(receiver));

        let mut workers = Vec::with_capacity(size);
        for id in 0..size {
            workers.push(Worker::new(id, Arc::clone(&receiver)));
        }

        ThreadPool {
            workers,
            sender: Some(sender),
        }
    }

ところでドキュメントに書いてあるのですが channel で呼び出されるのは非同期チャネルです。sync_channel というのもありまして、これで作られる SyncSender はバッファを事前に割り当てておいて、そのバッファが空くまで送信側を待たせるという挙動の違いがあります。例えば HTTP サーバのスレッドプールに使うなら channel でいいでしょう、送信時にチャネルバッファの容量なんか気にしたくないですし。sync_channel の使い所の例もドキュメントに書いてあって、送信から受信までをアトミックに実行したい場合にバッファサイズを0にして実行するというのはあるみたいです。ほー。

These channels come in two flavors:

An asynchronous, infinitely buffered channel. The channel function will return a (Sender, Receiver) tuple where all sends will be asynchronous (they never block). The channel conceptually has an infinite buffer.

A synchronous, bounded channel. The sync_channel function will return a (SyncSender, Receiver) tuple where the storage for pending messages is a pre-allocated buffer of a fixed size. All sends will be synchronous by blocking until there is buffer space available. Note that a bound of 0 is allowed, causing the channel to become a “rendezvous” channel where each sender atomically hands off a message to a receiver.

Receiver は各 Worker から排他的に参照できるようにしたいので Arc です。

Worker 初期化

Worker の初期化も見ていきます。

struct Worker {
    id: usize,
    thread: thread::JoinHandle<()>,
}

impl Worker {
    fn new(id: usize, receiver: Arc<Mutex<Receiver<Job>>>) -> Worker {
        let thread = thread::spawn(move || {
            loop {
                println!("Worker {id} is working.");
                let lock = receiver.lock().unwrap();
                println!("Worker {id} acquired lock");
                let message = lock.recv();

                match message {
                    Ok(job) => {
                        println!("Worker {id} got a job; executing.");
                        job();
                    }

                    Err(_) => {
                        println!("Worker {id} disconnected; shutting down.");
                        break;
                    }
                }
            }
        });

        Worker { id, thread }
    }
}

初期化した時点で子スレッドが動き出し、receiver のロックを我先にと取りに行きます。が、receiver は Arc なので、ロックを取れるのは Worker のうちいずれか一つです。(という様子が println マクロの表示からわかります) で、このロックを取ったスレッドは、receiver にメッセージが来るまで同期的に待機します。recv は receiver からのメッセージ着信を同期的に待機するからです。

Attempts to wait for a value on this receiver, returning an error if the corresponding channel has hung up.

This function will always block the current thread if there is no data available and it’s possible for more data to be sent (at least one sender still exists). Once a message is sent to the corresponding Sender (or SyncSender), this receiver will wake up and return that message.

execute

receiver へメッセージを届ける(= sender に send する)実際の処理は execute に定義しています。

    pub fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let job = Box::new(f);
        self.sender.as_ref().unwrap().send(job).unwrap();
    }

使い方はこんな感じ。

let pool = ThreadPool::new(10);
pool.execute(|| { /* 実際の処理 */ })

Box<dyn FnOnce() + Send + 'static> という型がだいぶややこしい雰囲気がします。要するに1回以上呼ばれるものであってスレッドをまたいで所有権を転送できる、かつプログラムそれ自体と同じくらい長く生きるもののことを指しているんですが……

まず FnOnceSend も trait です。つまり Box の中に入っているこの型は trait object というやつになります。特に注意すべきこととして FnOnce() + Send とだけ書かれてもコンパイル時に具体的な型が決定できません。そこで dyn と書くとこの trait object が動的ディスパッチされることになるのでコンパイルを通ります。その代わりランタイムコストが余計にかかります。といったことが dyn の説明に書いてあります。

Box でくるまれているのは FnOnceSized じゃないからです。それはそう。

実際スレッドプールにどんな処理が投げられてくるかについて特に仮定を置かないのであればこういうふうに書くことになりそうです。

Send についての詳細はまたの機会に深堀りしようと思いますが、実を言うとこれも RPL に書いてあります。あと nomicon にあるのも見てあるので、このあたり読めばよさそうな雰囲気がする。

感想

Rust 以外の言語も含めて、スレッドプールを自作するのは実はこれがはじめてでした。色々と学ぶところが多かったです。