//! Runtime time primitives: the [`sleep`] future, the [`timeout`] combinator,
//! and the [`Elapsed`] / [`timeout_error`] values used to report an expired deadline.
|
|
|
|
use std::cell::{Cell, RefCell};
|
|
use std::fmt;
|
|
use std::future::{Future, poll_fn};
|
|
use std::io;
|
|
use std::pin::Pin;
|
|
use std::rc::Rc;
|
|
use std::task::Waker;
|
|
use std::task::{Context, Poll};
|
|
use std::time::Duration;
|
|
|
|
use crate::{clear_timeout, set_timeout};
|
|
|
|
pub struct Sleep {
|
|
delay: Option<Duration>,
|
|
state: Option<Rc<SleepState>>,
|
|
handle: Option<crate::TimeoutHandle>,
|
|
completed: bool,
|
|
}
|
|
|
|
/// Error value returned by [`timeout`] when the deadline timer completes before
/// the wrapped future does.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Elapsed;
|
|
|
|
pub fn sleep(duration: Duration) -> Sleep {
|
|
Sleep {
|
|
delay: Some(duration),
|
|
state: None,
|
|
handle: None,
|
|
completed: false,
|
|
}
|
|
}
|
|
|
|
pub async fn timeout<F>(duration: Duration, future: F) -> Result<F::Output, Elapsed>
|
|
where
|
|
F: Future,
|
|
{
|
|
let mut future = std::pin::pin!(future);
|
|
let mut sleeper = std::pin::pin!(sleep(duration));
|
|
|
|
poll_fn(|cx| {
|
|
if let Poll::Ready(output) = future.as_mut().poll(cx) {
|
|
return Poll::Ready(Ok(output));
|
|
}
|
|
|
|
if let Poll::Ready(()) = sleeper.as_mut().poll(cx) {
|
|
return Poll::Ready(Err(Elapsed));
|
|
}
|
|
|
|
Poll::Pending
|
|
})
|
|
.await
|
|
}
|
|
|
|
/// Builds an [`io::Error`] of kind [`io::ErrorKind::TimedOut`] whose message is
/// `"<action> timed out"`.
pub fn timeout_error(action: &'static str) -> io::Error {
    let message = format!("{action} timed out");
    io::Error::new(io::ErrorKind::TimedOut, message)
}
|
|
|
|
impl Future for Sleep {
|
|
type Output = ();
|
|
|
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
if self.completed {
|
|
return Poll::Ready(());
|
|
}
|
|
|
|
if self.state.is_none() {
|
|
let delay = self.delay.take().unwrap_or(Duration::ZERO);
|
|
let state = Rc::new(SleepState::default());
|
|
let state_for_callback = Rc::clone(&state);
|
|
let timeout_handle = set_timeout(delay, move || state_for_callback.complete());
|
|
self.state = Some(state);
|
|
self.handle = Some(timeout_handle);
|
|
}
|
|
|
|
let state = self
|
|
.state
|
|
.as_ref()
|
|
.expect("sleep state should be initialized");
|
|
if state.ready.get() {
|
|
self.completed = true;
|
|
self.state = None;
|
|
self.handle = None;
|
|
Poll::Ready(())
|
|
} else {
|
|
*state.waker.borrow_mut() = Some(cx.waker().clone());
|
|
if state.ready.get() {
|
|
self.completed = true;
|
|
self.state = None;
|
|
self.handle = None;
|
|
Poll::Ready(())
|
|
} else {
|
|
Poll::Pending
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Drop for Sleep {
|
|
fn drop(&mut self) {
|
|
if self.completed {
|
|
return;
|
|
}
|
|
|
|
if let Some(handle) = self.handle.take() {
|
|
clear_timeout(&handle);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Completion flag and waker slot shared between a `Sleep` future and the
/// callback it registers with the timer.
#[derive(Default)]
struct SleepState {
    /// Set to `true` by [`SleepState::complete`].
    ready: Cell<bool>,
    /// Waker stored by the most recent pending poll; taken when the state completes.
    waker: RefCell<Option<Waker>>,
}

impl SleepState {
    /// Marks the state as ready and wakes the stored waker, if there is one.
    fn complete(&self) {
        self.ready.set(true);

        // Take the waker in its own statement so the `RefMut` guard is dropped
        // before `wake()` runs. Borrowing inside the `if let` scrutinee would keep
        // the slot mutably borrowed for the whole body, and a waker that touches
        // the slot re-entrantly would then panic with `BorrowMutError`.
        let waker = self.waker.borrow_mut().take();
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}
|
|
|
|
impl fmt::Display for Elapsed {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
f.write_str("deadline elapsed")
|
|
}
|
|
}
|
|
|
|
impl std::error::Error for Elapsed {}
|
|
|
|
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    use crate::{queue_future, queue_task, run};

    use super::{sleep, timeout};

    /// Drives the runtime on a dedicated thread and checks that a plain `sleep`
    /// completes and that `timeout` reports an error when the wrapped future
    /// sleeps longer than the deadline.
    #[test]
    fn sleep_and_timeout_work() {
        let worker = std::thread::spawn(|| {
            let events = Arc::new(Mutex::new(Vec::new()));
            let events_for_task = Arc::clone(&events);

            queue_task(move || {
                let events = Arc::clone(&events_for_task);
                queue_future(async move {
                    events.lock().unwrap().push("started");
                    sleep(Duration::from_millis(5)).await;
                    events.lock().unwrap().push("slept");

                    // The inner sleep (20 ms) outlasts the 5 ms deadline.
                    let slow = async {
                        sleep(Duration::from_millis(20)).await;
                        42usize
                    };
                    let result = timeout(Duration::from_millis(5), slow).await;
                    assert!(result.is_err(), "timeout should fire first");
                    events.lock().unwrap().push("timed out");
                });
            });
            run();

            // Bind the snapshot so the mutex guard is released before `events` drops.
            let recorded = events.lock().unwrap().clone();
            recorded
        });

        let log = worker
            .join()
            .expect("time test thread should join successfully");

        assert_eq!(log.as_slice(), ["started", "slept", "timed out"]);
    }
}
|