#![allow(unused)]
use std::sync::{mpsc, Arc, Mutex};
use std::thread;
// 定义消息类型,可以是新任务或终止信号
enum Message {
NewJob(Job),
Terminate,
}
// 定义线程池结构体
pub struct ThreadPool {
workers: Vec<Worker>,
// sender: mpsc::Sender<Job>,
sender: mpsc::Sender<Message>,
}
// 定义任务类型,可以是任何实现了FnOnce trait的闭包
type Job = Box<dyn FnOnce() + Send + 'static>;
impl ThreadPool {
// 创建线程池,参数为线程数量
pub fn new(size: usize) -> ThreadPool {
assert!(size > 0);
// 创建一个通道,用于发送任务
let (sender, receiver) = mpsc::channel();
// 将通道包装成Arc<Mutex>,以便多个线程共享
let receiver = Arc::new(Mutex::new(receiver));
let mut workers = Vec::with_capacity(size);
// 创建指定数量的工作线程
for id in 0..size {
workers.push(Worker::new(id, Arc::clone(&receiver)));
}
ThreadPool { workers, sender }
}
// 执行任务,参数为任务闭包
pub fn execute<F>(&self, f: F)
where
F: FnOnce() + Send + 'static,
{
// 将任务包装成Box<dyn FnOnce() + Send + 'static>,并发送到通道
let job = Box::new(f);
// self.sender.send(job).unwrap();
self.sender.send(Message::NewJob(job)).unwrap();
}
}
// 实现Drop trait,在线程池被销毁时执行清理操作
impl Drop for ThreadPool {
fn drop(&mut self) {
// 向每个工作线程发送终止信号
for _ in &mut self.workers {
self.sender.send(Message::Terminate).unwrap();
}
println!("Shutting down all workers.");
// 等待每个工作线程结束
for worker in &mut self.workers {
println!("Shutting down worker {}", worker.id);
if let Some(thread) = worker.thread.take() {
thread.join().unwrap();
}
}
}
}
// 定义工作线程结构体
struct Worker {
id: usize,
// thread: thread::JoinHandle<()>,
thread: Option<thread::JoinHandle<()>>,
}
impl Worker {
// 创建工作线程,参数为线程ID和通道
fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Message>>>) -> Worker {
// 创建一个新线程,循环接收任务并执行
let thread = thread::spawn(move || loop {
// let job = receiver.lock().unwrap().recv().unwrap();
// println!("Worker {} got a job; executing.", id);
// job();
// 接收任务
let message = receiver.lock().unwrap().recv().unwrap();
// 根据任务类型执行相应操作
match message {
Message::NewJob(job) => {
println!("Worker {} executing job.", id);
job();
}
Message::Terminate => {
println!("Worker {} terminated.", id);
break;
}
}
});
Worker {
id,
thread: Some(thread),
}
}
}