1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
use futures::Future;
use log::trace;
use tokio::runtime::Runtime;
use tokio::task::JoinError;

/// Runs blocking closures and async tasks on a Tokio [`Runtime`] that it owns.
pub struct TaskManager {
    /// The Tokio runtime on which every task handed to this manager is spawned.
    runtime: Runtime,
}

impl TaskManager {
    /// Creates a `TaskManager` backed by a new multi-threaded Tokio runtime.
    ///
    /// `thread_count` is handed to the runtime builder as the limit on blocking
    /// threads, i.e. the threads that execute closures submitted through
    /// [`add_task`](Self::add_task) and
    /// [`run_task_blocking`](Self::run_task_blocking). The number of async
    /// worker threads is not configured here.
    ///
    /// A `thread_count` of `0` is treated as `1`, because the Tokio builder
    /// panics when given a blocking-thread limit that is not larger than zero.
    ///
    /// # Panics
    /// Panics with `"tokio rt failed"` if the runtime cannot be built.
    pub fn new(thread_count: usize) -> Self {
        // Tokio rejects a zero limit with a panic; clamp so `new(0)` still
        // yields a usable manager with a single blocking thread.
        let blocking_threads = thread_count.max(1);

        let runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .max_blocking_threads(blocking_threads)
            .build()
            .expect("tokio rt failed");

        TaskManager { runtime }
    }

    /// Submits a blocking closure to the runtime's blocking thread pool and
    /// returns immediately.
    ///
    /// The join handle is discarded, so the task is detached: the caller can
    /// neither wait for it nor observe whether it panicked.
    pub fn add_task<T: FnOnce() + Send + 'static>(&self, task: T) {
        trace!("adding a task");
        self.runtime.spawn_blocking(task);
    }

    /// Spawns an async task on the runtime and returns a future that resolves
    /// to the task's output, or to a [`JoinError`] if the task did not complete.
    ///
    /// Per the Tokio documentation a spawned task starts running right away,
    /// so awaiting the returned future is only needed to obtain the result.
    ///
    /// # Example
    /// ```rust
    /// use hirofa_utils::task_manager::TaskManager;
    /// let tm = TaskManager::new(2);
    /// let task = async {
    ///     println!("foo");
    /// };
    /// tm.add_task_async(task);
    /// ```
    pub fn add_task_async<R: Send + 'static, T: Future<Output = R> + Send + 'static>(
        &self,
        task: T,
    ) -> impl Future<Output = Result<R, JoinError>> {
        self.runtime.spawn(task)
    }

    /// Runs a blocking closure on the runtime's blocking thread pool and blocks
    /// the calling thread until it has finished, returning the closure's result.
    ///
    /// # Panics
    /// - Panics with `"task failed"` if the task could not be joined, for
    ///   example because the closure itself panicked.
    /// - No check for the calling context is performed here;
    ///   `Runtime::block_on` is documented by Tokio to panic when it is called
    ///   from within an asynchronous execution context.
    ///
    /// NOTE(review): calling this from inside a closure that is itself running
    /// on the blocking pool looks like it can deadlock once every blocking
    /// thread is occupied (the caller holds a thread while waiting for a free
    /// one) — confirm against callers.
    #[allow(dead_code)]
    pub fn run_task_blocking<R: Send + 'static, T: FnOnce() -> R + Send + 'static>(
        &self,
        task: T,
    ) -> R {
        trace!("adding a sync task from thread {}", thread_id::get());
        let join_handle = self.runtime.spawn_blocking(task);
        self.runtime.block_on(join_handle).expect("task failed")
    }
}

#[cfg(test)]
mod tests {
    use crate::task_manager::TaskManager;
    use log::trace;
    use std::thread;
    use std::time::Duration;

    /// Smoke test: detached tasks can be queued, and blocking calls hand back
    /// the value produced by their closure.
    #[test]
    fn test() {
        trace!("testing");

        let one_second = Duration::from_secs(1);
        let manager = TaskManager::new(1);

        // Queue five fire-and-forget jobs that each sleep for a second.
        for _ in 0..5 {
            manager.add_task(move || thread::sleep(one_second));
        }

        // A slow blocking call must still return its closure's value.
        let slow_result = manager.run_task_blocking(move || {
            thread::sleep(one_second);
            "res"
        });
        assert_eq!(slow_result, "res");

        // Repeated quick blocking calls keep returning the expected value.
        for _ in 0..10 {
            assert_eq!(manager.run_task_blocking(|| "res"), "res");
        }
    }
}