use futures::Future;
use log::trace;
use tokio::runtime::Runtime;
use tokio::task::JoinError;
/// Owns a multi-threaded tokio [`Runtime`] and exposes helpers for submitting
/// blocking closures and futures to it.
#[derive(Debug)]
pub struct TaskManager {
    /// The runtime every task submitted through this manager is spawned on.
    runtime: Runtime,
}
impl TaskManager {
    /// Builds a multi-threaded tokio runtime with all drivers enabled.
    ///
    /// `thread_count` is passed to `Builder::max_blocking_threads`, so it caps
    /// the pool used by [`TaskManager::add_task`] and
    /// [`TaskManager::run_task_blocking`]. The number of async worker threads
    /// is left at tokio's default.
    ///
    /// # Panics
    ///
    /// Panics if `thread_count` is 0 (tokio rejects a zero-sized blocking
    /// pool) or if the runtime cannot be built.
    pub fn new(thread_count: usize) -> Self {
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .max_blocking_threads(thread_count)
            .build()
            .expect("tokio rt failed");
        TaskManager { runtime }
    }

    /// Queues `task` on the blocking thread pool and returns immediately.
    ///
    /// This is fire-and-forget. The task's `JoinHandle` is discarded, so the
    /// caller can neither wait for the task nor observe its outcome.
    pub fn add_task<T: FnOnce() + Send + 'static>(&self, task: T) {
        trace!("adding a task");
        // Dropping the handle detaches the task; it is not cancelled.
        let _ = self.runtime.spawn_blocking(task);
    }

    /// Spawns the future `task` onto the runtime's async workers.
    ///
    /// The task is spawned by this call itself; the returned future only needs
    /// to be awaited to obtain the result. It resolves to `Err(JoinError)` if
    /// the task did not complete normally (for example, it panicked).
    pub fn add_task_async<R: Send + 'static, T: Future<Output = R> + Send + 'static>(
        &self,
        task: T,
    ) -> impl Future<Output = Result<R, JoinError>> {
        trace!("adding an async task");
        self.runtime.spawn(task)
    }

    /// Runs `task` on the blocking thread pool and blocks the calling thread
    /// until it finishes, returning its result.
    ///
    /// # Panics
    ///
    /// If `task` panics, that panic is resumed on the calling thread with its
    /// original payload. This function also panics if called from within an
    /// async execution context (e.g. from a task running on this runtime),
    /// because `Runtime::block_on` forbids that.
    // NOTE(review): allowed because no non-test caller is visible here; confirm
    // before removing the attribute or the method.
    #[allow(dead_code)]
    pub fn run_task_blocking<R: Send + 'static, T: FnOnce() -> R + Send + 'static>(
        &self,
        task: T,
    ) -> R {
        trace!("adding a sync task from thread {}", thread_id::get());
        let join_handle = self.runtime.spawn_blocking(task);
        match self.runtime.block_on(join_handle) {
            Ok(result) => result,
            // Re-raise the task's own panic so callers (and `should_panic`
            // tests) see the real message rather than a wrapped JoinError.
            Err(err) if err.is_panic() => std::panic::resume_unwind(err.into_panic()),
            Err(err) => panic!("task failed: {}", err),
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::task_manager::TaskManager;
    use log::trace;
    use std::thread;
    use std::time::Duration;

    /// Per-task sleep. It is kept short so the test stays fast. With a single
    /// blocking thread the queued tasks still run one after another, so the
    /// blocking call below still has to wait behind them.
    const SLEEP: Duration = Duration::from_millis(50);

    #[test]
    fn test() {
        trace!("testing");
        // thread_count = 1 caps the blocking pool at one thread, so all the
        // blocking work below is serialized onto it.
        let tm = TaskManager::new(1);
        // Fill the blocking pool with fire-and-forget work first...
        for _ in 0..5 {
            tm.add_task(|| thread::sleep(SLEEP));
        }
        // ...then check a blocking call queued behind it still returns its result.
        let s = tm.run_task_blocking(|| {
            thread::sleep(SLEEP);
            "res"
        });
        assert_eq!(s, "res");
        // Back-to-back blocking calls each return their own result.
        for _ in 0..10 {
            assert_eq!(tm.run_task_blocking(|| "res"), "res");
        }
    }
}