Files
ruin/lib/runtime/src/time.rs

176 lines
4.5 KiB
Rust

//! Runtime time primitives.
use std::cell::{Cell, RefCell};
use std::fmt;
use std::future::{Future, poll_fn};
use std::io;
use std::pin::Pin;
use std::rc::Rc;
use std::task::Waker;
use std::task::{Context, Poll};
use std::time::Duration;
use crate::{clear_timeout, set_timeout};
/// Future returned by [`sleep`] that resolves once its delay has elapsed.
///
/// The timer is armed on the first poll, so the delay is measured from then
/// rather than from construction. Dropping a `Sleep` before it completes
/// cancels the pending timer.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Sleep {
    /// Requested delay; taken (left `None`) when the timer is armed.
    delay: Option<Duration>,
    /// State shared with the timer callback; `None` before arming and after completion.
    state: Option<Rc<SleepState>>,
    /// Handle used to cancel the timer if the sleep is dropped early.
    handle: Option<crate::TimeoutHandle>,
    /// Set once the sleep has resolved; later polls return `Ready` immediately.
    completed: bool,
}
/// Error produced by [`timeout`] when the deadline passes before the wrapped
/// future finishes.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct Elapsed;
/// Creates a future that completes once `duration` has passed.
///
/// The returned [`Sleep`] is lazy: its timer is only registered when it is
/// first polled.
pub fn sleep(duration: Duration) -> Sleep {
    let delay = Some(duration);
    Sleep {
        completed: false,
        handle: None,
        state: None,
        delay,
    }
}
/// Runs `future` to completion unless `duration` elapses first.
///
/// Resolves to `Ok(output)` when the future finishes first and to
/// `Err(Elapsed)` when the deadline wins. On every wake-up the wrapped future
/// is polled before the timer, so if both are ready in the same poll the
/// output wins. The timer is armed the first time the wrapped future reports
/// `Pending`.
pub async fn timeout<F>(duration: Duration, future: F) -> Result<F::Output, Elapsed>
where
    F: Future,
{
    let mut inner = std::pin::pin!(future);
    let mut deadline = std::pin::pin!(sleep(duration));
    poll_fn(|cx| match inner.as_mut().poll(cx) {
        Poll::Ready(value) => Poll::Ready(Ok(value)),
        Poll::Pending => deadline.as_mut().poll(cx).map(|()| Err(Elapsed)),
    })
    .await
}
/// Builds an [`io::Error`] of kind [`io::ErrorKind::TimedOut`] whose message
/// reads `"<action> timed out"`.
pub fn timeout_error(action: &'static str) -> io::Error {
    let message = format!("{action} timed out");
    io::Error::new(io::ErrorKind::TimedOut, message)
}
impl Future for Sleep {
    type Output = ();

    /// Arms the timer on the first poll and resolves once it has fired.
    ///
    /// The delay is measured from the first call to `poll`, not from when the
    /// `Sleep` was constructed. After this returns `Poll::Ready(())`, every
    /// later poll returns `Poll::Ready(())` immediately.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Relies on `Sleep: Unpin`; nothing here is self-referential, so a
        // plain `&mut Sleep` is all we need.
        let this = self.get_mut();

        if this.completed {
            return Poll::Ready(());
        }

        if this.state.is_none() {
            let delay = this.delay.take().unwrap_or(Duration::ZERO);
            let state = Rc::new(SleepState::default());
            let callback_state = Rc::clone(&state);
            this.handle = Some(set_timeout(delay, move || callback_state.complete()));
            this.state = Some(state);
        }

        let state = this
            .state
            .as_ref()
            .expect("sleep state should be initialized");

        if state.ready.get() {
            // Timer fired: release the shared state and the spent handle so
            // `Drop` has nothing left to cancel.
            this.completed = true;
            this.state = None;
            this.handle = None;
            return Poll::Ready(());
        }

        // Still pending. The timer callback runs on this same thread (the
        // state is `Rc`/`Cell`), so it cannot fire between the check above
        // and this registration. No second `ready` check is needed.
        // Only clone the waker when the stored one targets a different task.
        let mut slot = state.waker.borrow_mut();
        match slot.as_ref() {
            Some(stored) if stored.will_wake(cx.waker()) => {}
            _ => *slot = Some(cx.waker().clone()),
        }
        Poll::Pending
    }
}
impl Drop for Sleep {
    /// Cancels the timer of a sleep that is dropped before it completes.
    fn drop(&mut self) {
        // A completed sleep only resolved because its timer already fired,
        // so there is nothing to cancel in that case.
        let pending = if self.completed {
            None
        } else {
            self.handle.take()
        };
        if let Some(handle) = pending {
            clear_timeout(&handle);
        }
    }
}
/// State shared between a pending `Sleep` and its timer callback.
#[derive(Default)]
struct SleepState {
    /// Set to `true` by `complete` once the timer callback runs.
    ready: Cell<bool>,
    /// Waker of the task that last polled the sleep while it was pending.
    waker: RefCell<Option<Waker>>,
}

impl SleepState {
    /// Marks the sleep as finished and wakes the registered task, if any.
    fn complete(&self) {
        self.ready.set(true);
        // Take the waker in its own statement so the `RefCell` borrow is
        // released before `wake()` runs arbitrary executor code. That code
        // may re-enter this state and borrow `waker` again.
        let waker = self.waker.borrow_mut().take();
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}
impl fmt::Display for Elapsed {
    /// Writes the fixed message `deadline elapsed`, ignoring width/fill flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "deadline elapsed")
    }
}

/// `Elapsed` wraps no underlying cause, so the default `Error` methods suffice.
impl std::error::Error for Elapsed {}
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    use super::{sleep, timeout};
    use crate::{queue_future, queue_task, run};

    /// End-to-end check on a fresh thread: a short sleep completes, then a
    /// timeout whose inner sleep outlasts the deadline reports `Err`.
    #[test]
    fn sleep_and_timeout_work() {
        let worker = std::thread::spawn(|| {
            let events = Arc::new(Mutex::new(Vec::new()));
            let task_events = Arc::clone(&events);

            queue_task(move || {
                let future_events = Arc::clone(&task_events);
                queue_future(async move {
                    future_events.lock().unwrap().push("started");

                    sleep(Duration::from_millis(5)).await;
                    future_events.lock().unwrap().push("slept");

                    let slow = async {
                        sleep(Duration::from_millis(20)).await;
                        42usize
                    };
                    let outcome = timeout(Duration::from_millis(5), slow).await;
                    assert!(outcome.is_err(), "timeout should fire first");
                    future_events.lock().unwrap().push("timed out");
                });
            });

            // Drive the event loop on this thread; presumably returns once the
            // queued work has drained.
            run();

            let recorded = events.lock().unwrap();
            recorded.clone()
        });

        let events = worker
            .join()
            .expect("time test thread should join successfully");
        assert_eq!(events, ["started", "slept", "timed out"]);
    }
}