hirofa_utils/
task_manager.rs

1use futures::Future;
2use log::trace;
3use tokio::runtime::Runtime;
4use tokio::task::JoinError;
5
6pub struct TaskManager {
7    runtime: Runtime,
8}
9
10impl TaskManager {
11    pub fn new(thread_count: usize) -> Self {
12        // start threads
13
14        let runtime = tokio::runtime::Builder::new_multi_thread()
15            .enable_all()
16            .max_blocking_threads(thread_count)
17            .build()
18            .expect("tokio rt failed");
19
20        TaskManager { runtime }
21    }
22
23    pub fn add_task<T: FnOnce() + Send + 'static>(&self, task: T) {
24        trace!("adding a task");
25        self.runtime.spawn_blocking(task);
26    }
27
28    /// start an async task
29    /// # Example
30    /// ```rust
31    /// use hirofa_utils::task_manager::TaskManager;
32    /// let tm = TaskManager::new(2);
33    /// let task = async {
34    ///     println!("foo");
35    /// };
36    /// tm.add_task_async(task);
37    /// ```
38    pub fn add_task_async<R: Send + 'static, T: Future<Output = R> + Send + 'static>(
39        &self,
40        task: T,
41    ) -> impl Future<Output = Result<R, JoinError>> {
42        self.runtime.spawn(task)
43    }
44
45    #[allow(dead_code)]
46    pub fn run_task_blocking<R: Send + 'static, T: FnOnce() -> R + Send + 'static>(
47        &self,
48        task: T,
49    ) -> R {
50        trace!("adding a sync task from thread {}", thread_id::get());
51        // check if the current thread is not a worker thread, because that would be bad
52        let join_handle = self.runtime.spawn_blocking(task);
53        self.runtime.block_on(join_handle).expect("task failed")
54    }
55}
56
#[cfg(test)]
mod tests {
    use crate::task_manager::TaskManager;
    use log::trace;
    use std::sync::mpsc;
    use std::thread;
    use std::time::Duration;

    /// Upper bound when waiting for background tasks; generous so slow CI
    /// machines do not cause spurious failures.
    const WAIT: Duration = Duration::from_secs(10);

    #[test]
    fn test() {
        trace!("testing");

        let tm = TaskManager::new(1);

        // Fire-and-forget tasks must actually run; each reports its index.
        let (tx, rx) = mpsc::channel();
        for x in 0..5 {
            let tx = tx.clone();
            tm.add_task(move || {
                thread::sleep(Duration::from_millis(10));
                // Receiver only goes away if the test already failed.
                let _ = tx.send(x);
            })
        }
        drop(tx);

        let s = tm.run_task_blocking(|| {
            thread::sleep(Duration::from_millis(10));
            "res"
        });

        assert_eq!(s, "res");

        // Completion order is not guaranteed, so sort before comparing.
        let mut done: Vec<i32> = (0..5)
            .map(|_| rx.recv_timeout(WAIT).expect("add_task task did not run"))
            .collect();
        done.sort_unstable();
        assert_eq!(done, vec![0, 1, 2, 3, 4]);

        for _x in 0..10 {
            let s = tm.run_task_blocking(|| "res");

            assert_eq!(s, "res");
        }
    }

    #[test]
    fn test_async() {
        let tm = TaskManager::new(1);
        let (tx, rx) = mpsc::channel();
        // The spawned task runs even though the returned future is never polled.
        let _join_handle = tm.add_task_async(async move {
            let _ = tx.send(42);
        });
        assert_eq!(rx.recv_timeout(WAIT), Ok(42));
    }

    #[test]
    #[should_panic]
    fn run_task_blocking_panics_when_task_panics() {
        let tm = TaskManager::new(1);
        let _: u8 = tm.run_task_blocking(|| panic!("boom"));
    }
}