//! Public runtime driver primitives. use std::cell::Cell; use std::cell::RefCell; use std::collections::HashMap; use std::io; use std::os::fd::RawFd; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; use super::uring::{IORING_OP_ASYNC_CANCEL, IoUring, IoUringCqe, IoUringSqe}; use crate::trace_targets; const WAKE_TARGET_TOKEN: u64 = 1; const TOKEN_KIND_SHIFT: u64 = 56; const TOKEN_KIND_MASK: u64 = 0xff << TOKEN_KIND_SHIFT; #[derive(Clone, Copy, Debug, Eq, PartialEq)] #[repr(u8)] enum CompletionKind { Timer = 1, TimerRemove = 2, NotifySend = 3, Operation = 4, OperationCancel = 5, } type CompletionHandler = Box; struct NotifierInner { ring_fd: RawFd, closed: AtomicBool, } impl NotifierInner { fn notify(&self) -> io::Result<()> { #[cfg(debug_assertions)] tracing::trace!( target: trace_targets::DRIVER, event = "notify", ring_fd = self.ring_fd, "sending cross-thread driver wake" ); if self.closed.load(Ordering::Acquire) { return Err(io::Error::new( io::ErrorKind::BrokenPipe, "target runtime ring is closed", )); } IoUring::with_submitter(|ring| { ring.submit_msg_ring( self.ring_fd, WAKE_TARGET_TOKEN, 1, make_token(CompletionKind::NotifySend, 0), ) }) } } #[derive(Clone)] /// Cross-thread notifier for a runtime thread's driver. pub struct ThreadNotifier { inner: Arc, } impl ThreadNotifier { /// Sends a wake notification to the target runtime thread. pub fn notify(&self) -> io::Result<()> { self.inner.notify() } } #[derive(Debug, Default, Clone, Copy, Eq, PartialEq)] /// Readiness information returned by [`Driver::poll`]. pub struct ReadyEvents { /// One or more timer expirations are pending. pub timer: bool, /// One or more cross-thread wake notifications are pending. pub wake: bool, } /// Low-level Linux runtime driver backed by `io_uring`. pub struct Driver { ring: IoUring, notifier: Arc, next_token: Cell, active_timer_token: Cell>, pending_wakes: Cell, pending_timers: Cell, completions: RefCell>, } /// Creates a new driver and its paired [`ThreadNotifier`]. pub fn create() -> io::Result<(Driver, ThreadNotifier)> { create_driver() } /// Creates a new driver and its paired [`ThreadNotifier`]. /// /// This is identical to [`create`] and exists as a more explicit name for callers that want to /// emphasize driver construction. pub fn create_driver() -> io::Result<(Driver, ThreadNotifier)> { let ring = IoUring::new(64)?; tracing::debug!( target: trace_targets::DRIVER, event = "create_driver", ring_fd = ring.ring_fd(), "created runtime driver" ); let notifier = Arc::new(NotifierInner { ring_fd: ring.ring_fd(), closed: AtomicBool::new(false), }); Ok(( Driver { ring, notifier: Arc::clone(¬ifier), next_token: Cell::new(1), active_timer_token: Cell::new(None), pending_wakes: Cell::new(0), pending_timers: Cell::new(0), completions: RefCell::new(HashMap::new()), }, ThreadNotifier { inner: notifier }, )) } impl Driver { pub(crate) fn bind_current_thread(&self) { self.ring.bind_current_thread(); } pub(crate) fn unbind_current_thread(&self) { self.ring.unbind_current_thread(); } /// Polls the driver without blocking. pub fn poll(&self) -> io::Result> { let mut ready = ReadyEvents::default(); let saw_any = self .ring .drain_completions(|cqe| self.process_cqe(cqe, &mut ready)); #[cfg(debug_assertions)] if saw_any { tracing::trace!( target: trace_targets::DRIVER, event = "poll_ready", timer_ready = ready.timer, wake_ready = ready.wake, "driver poll produced ready events" ); } if saw_any { Ok(Some(ready)) } else { Ok(None) } } /// Blocks until at least one completion is available. pub fn wait(&self) -> io::Result<()> { #[cfg(debug_assertions)] tracing::trace!( target: trace_targets::DRIVER, event = "wait", "waiting for driver completion" ); self.ring.wait_for_cqe() } /// Updates the currently armed timer deadline. /// /// Passing `None` removes any active timer. pub fn rearm_timer(&self, deadline: Option) -> io::Result<()> { #[cfg(debug_assertions)] tracing::trace!( target: trace_targets::TIMER, event = "rearm_timer", deadline_ns = deadline.map(|value| value.as_nanos() as u64), "rearming driver timer" ); match (self.active_timer_token.get(), deadline) { (Some(active), Some(deadline)) => { self.ring.submit_timeout_update(active, deadline)?; } (Some(active), None) => { self.active_timer_token.set(None); self.ring .submit_timeout_remove(active, self.next_token(CompletionKind::TimerRemove))?; } (None, Some(deadline)) => { let token = self.next_token(CompletionKind::Timer); self.active_timer_token.set(Some(token)); self.ring.submit_timeout(token, deadline)?; } (None, None) => {} } Ok(()) } pub(crate) fn submit_operation( &self, fill: impl FnOnce(&mut IoUringSqe), on_complete: impl FnOnce(IoUringCqe) + Send + 'static, ) -> io::Result { let token = self.next_token(CompletionKind::Operation); #[cfg(debug_assertions)] tracing::trace!( target: trace_targets::ASYNC, event = "submit_operation", token, "submitting async driver operation" ); self.completions .borrow_mut() .insert(token, Box::new(on_complete)); if let Err(error) = self.ring.submit_with_token(token, fill) { let _ = self.completions.borrow_mut().remove(&token); return Err(error); } Ok(token) } pub(crate) fn cancel_operation(&self, token: u64) -> io::Result<()> { #[cfg(debug_assertions)] tracing::trace!( target: trace_targets::ASYNC, event = "cancel_operation", token, "submitting async driver cancellation" ); self.ring .submit_with_token(self.next_token(CompletionKind::OperationCancel), |sqe| { sqe.opcode = IORING_OP_ASYNC_CANCEL; sqe.fd = -1; sqe.addr = token; }) } /// Drains the accumulated wake notification count. pub fn drain_wake(&self) -> io::Result { let wakes = self.pending_wakes.replace(0); if wakes == 0 { Err(io::Error::new( io::ErrorKind::WouldBlock, "no wake completions are pending", )) } else { Ok(wakes) } } /// Drains the accumulated timer-expiration count. pub fn drain_timer(&self) -> io::Result { let timers = self.pending_timers.replace(0); if timers == 0 { Err(io::Error::new( io::ErrorKind::WouldBlock, "no timer completions are pending", )) } else { Ok(timers) } } fn process_cqe(&self, cqe: IoUringCqe, ready: &mut ReadyEvents) { #[cfg(debug_assertions)] tracing::trace!( target: trace_targets::DRIVER, event = "process_cqe", user_data = cqe.user_data, result = cqe.res, "processing io_uring completion" ); if cqe.user_data == WAKE_TARGET_TOKEN { ready.wake = true; let wakes = cqe.res.max(1) as u64; self.pending_wakes .set(self.pending_wakes.get().saturating_add(wakes)); return; } match decode_token_kind(cqe.user_data) { Some(CompletionKind::Timer) => { if self.active_timer_token.get() == Some(cqe.user_data) { self.active_timer_token.set(None); } if cqe.res == -libc::ETIME { ready.timer = true; self.pending_timers .set(self.pending_timers.get().saturating_add(1)); } } Some(CompletionKind::Operation) => { if let Some(callback) = self.completions.borrow_mut().remove(&cqe.user_data) { callback(cqe); } } Some(CompletionKind::TimerRemove) | Some(CompletionKind::NotifySend) | Some(CompletionKind::OperationCancel) | None => {} } } fn next_token(&self, kind: CompletionKind) -> u64 { let seq = self.next_token.get(); self.next_token.set(seq.wrapping_add(1)); make_token(kind, seq) } } impl Drop for Driver { fn drop(&mut self) { tracing::debug!( target: trace_targets::DRIVER, event = "drop_driver", "dropping runtime driver" ); self.notifier.closed.store(true, Ordering::Release); } } /// Returns the current monotonic time used by the runtime timer system. pub fn monotonic_now() -> io::Result { let mut now = std::mem::MaybeUninit::::uninit(); let result = unsafe { libc::clock_gettime(libc::CLOCK_MONOTONIC, now.as_mut_ptr()) }; if result == -1 { return Err(io::Error::last_os_error()); } let now = unsafe { now.assume_init() }; Ok(Duration::new(now.tv_sec as u64, now.tv_nsec as u32)) } fn make_token(kind: CompletionKind, seq: u64) -> u64 { ((kind as u64) << TOKEN_KIND_SHIFT) | (seq & !TOKEN_KIND_MASK) } fn decode_token_kind(token: u64) -> Option { match ((token & TOKEN_KIND_MASK) >> TOKEN_KIND_SHIFT) as u8 { 1 => Some(CompletionKind::Timer), 2 => Some(CompletionKind::TimerRemove), 3 => Some(CompletionKind::NotifySend), 4 => Some(CompletionKind::Operation), 5 => Some(CompletionKind::OperationCancel), _ => None, } } #[cfg(test)] mod tests { use super::{create_driver, monotonic_now}; use std::thread; use std::time::Duration; #[test] fn notifier_wakes_target_ring() { let (sender, _) = create_driver().expect("sender driver should initialize"); sender.bind_current_thread(); let (target, notifier) = create_driver().expect("target driver should initialize"); notifier.notify().expect("notify should succeed"); let ready = loop { if let Some(ready) = target.poll().expect("poll should succeed") { break ready; } thread::sleep(Duration::from_millis(1)); }; assert!(ready.wake); assert!(!ready.timer); assert_eq!(target.drain_wake().expect("wake drain should succeed"), 1); sender.unbind_current_thread(); } #[test] fn notifier_wakes_target_ring_from_plain_thread() { let (target, notifier) = create_driver().expect("target driver should initialize"); thread::spawn(move || { notifier.notify().expect("notify should succeed"); }) .join() .expect("notifier thread should exit cleanly"); let ready = loop { if let Some(ready) = target.poll().expect("poll should succeed") { break ready; } thread::sleep(Duration::from_millis(1)); }; assert!(ready.wake); assert!(!ready.timer); assert_eq!(target.drain_wake().expect("wake drain should succeed"), 1); } #[test] fn timeout_reports_deadlines() { let (driver, _notifier) = create_driver().expect("driver should initialize"); let deadline = monotonic_now().expect("clock should work") + Duration::from_millis(20); driver .rearm_timer(Some(deadline)) .expect("timer should arm"); let ready = loop { if let Some(ready) = driver.poll().expect("poll should succeed") { break ready; } thread::sleep(Duration::from_millis(5)); }; assert!(ready.timer); assert!(!ready.wake); assert_eq!(driver.drain_timer().expect("timer drain should succeed"), 1); } }