//! Runtime time primitives. use std::cell::{Cell, RefCell}; use std::fmt; use std::future::{Future, poll_fn}; use std::io; use std::pin::Pin; use std::rc::Rc; use std::task::Waker; use std::task::{Context, Poll}; use std::time::Duration; use crate::{clear_timeout, set_timeout}; pub struct Sleep { delay: Option, state: Option>, handle: Option, completed: bool, } #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub struct Elapsed; pub fn sleep(duration: Duration) -> Sleep { Sleep { delay: Some(duration), state: None, handle: None, completed: false, } } pub async fn timeout(duration: Duration, future: F) -> Result where F: Future, { let mut future = std::pin::pin!(future); let mut sleeper = std::pin::pin!(sleep(duration)); poll_fn(|cx| { if let Poll::Ready(output) = future.as_mut().poll(cx) { return Poll::Ready(Ok(output)); } if let Poll::Ready(()) = sleeper.as_mut().poll(cx) { return Poll::Ready(Err(Elapsed)); } Poll::Pending }) .await } pub fn timeout_error(action: &'static str) -> io::Error { io::Error::new(io::ErrorKind::TimedOut, format!("{action} timed out")) } impl Future for Sleep { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { if self.completed { return Poll::Ready(()); } if self.state.is_none() { let delay = self.delay.take().unwrap_or(Duration::ZERO); let state = Rc::new(SleepState::default()); let state_for_callback = Rc::clone(&state); let timeout_handle = set_timeout(delay, move || state_for_callback.complete()); self.state = Some(state); self.handle = Some(timeout_handle); } let state = self .state .as_ref() .expect("sleep state should be initialized"); if state.ready.get() { self.completed = true; self.state = None; self.handle = None; Poll::Ready(()) } else { *state.waker.borrow_mut() = Some(cx.waker().clone()); if state.ready.get() { self.completed = true; self.state = None; self.handle = None; Poll::Ready(()) } else { Poll::Pending } } } } impl Drop for Sleep { fn drop(&mut self) { if self.completed { return; } if let Some(handle) = self.handle.take() { clear_timeout(&handle); } } } #[derive(Default)] struct SleepState { ready: Cell, waker: RefCell>, } impl SleepState { fn complete(&self) { self.ready.set(true); if let Some(waker) = self.waker.borrow_mut().take() { waker.wake(); } } } impl fmt::Display for Elapsed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str("deadline elapsed") } } impl std::error::Error for Elapsed {} #[cfg(test)] mod tests { use std::sync::{Arc, Mutex}; use std::time::Duration; use crate::{queue_future, queue_task, run}; use super::{sleep, timeout}; #[test] fn sleep_and_timeout_work() { let log = std::thread::spawn(|| { let log = Arc::new(Mutex::new(Vec::new())); let log_for_task = Arc::clone(&log); queue_task(move || { let log_for_task = Arc::clone(&log_for_task); queue_future(async move { log_for_task.lock().unwrap().push("started"); sleep(Duration::from_millis(5)).await; log_for_task.lock().unwrap().push("slept"); let result = timeout(Duration::from_millis(5), async { sleep(Duration::from_millis(20)).await; 42usize }) .await; assert!(result.is_err(), "timeout should fire first"); log_for_task.lock().unwrap().push("timed out"); }); }); run(); let log = log.lock().unwrap(); log.clone() }) .join() .expect("time test thread should join successfully"); assert_eq!(log.as_slice(), ["started", "slept", "timed out"]); } }