Skip to main content

tokio_test/
io.rs

1#![cfg(not(loom))]
2
3//! A mock type implementing [`AsyncRead`] and [`AsyncWrite`].
4//!
5//!
6//! # Overview
7//!
//! Provides a type that implements [`AsyncRead`] + [`AsyncWrite`] that can be configured
//! to handle an arbitrary sequence of read and write operations. This is useful
//! for writing unit tests for networking services, as using an actual network
//! type is fairly non-deterministic.
12//!
13//! # Usage
14//!
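//! Construct a mock with [`Builder`], scripting in order the reads, writes,
//! waits and errors it should see. Attempting to write data that the mock
//! isn't expecting will result in a panic.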
17//!
18//! [`AsyncRead`]: tokio::io::AsyncRead
19//! [`AsyncWrite`]: tokio::io::AsyncWrite
20
21use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
22use tokio::sync::mpsc;
23use tokio::time::{self, Duration, Instant, Sleep};
24use tokio_stream::wrappers::UnboundedReceiverStream;
25
26use futures_core::Stream;
27use std::collections::VecDeque;
28use std::fmt;
29use std::future::Future;
30use std::pin::Pin;
31use std::sync::Arc;
32use std::task::{self, ready, Poll, Waker};
33use std::{cmp, io};
34
35/// An I/O object that follows a predefined script.
36///
37/// This value is created by `Builder` and implements `AsyncRead` + `AsyncWrite`. It
38/// follows the scenario described by the builder and panics otherwise.
39#[derive(Debug)]
40pub struct Mock {
41    inner: Inner,
42}
43
44/// A handle to send additional actions to the related `Mock`.
45#[derive(Debug)]
46pub struct Handle {
47    tx: mpsc::UnboundedSender<Action>,
48}
49
50/// Builds `Mock` instances.
51#[derive(Debug, Clone, Default)]
52pub struct Builder {
53    // Sequence of actions for the Mock to take
54    actions: VecDeque<Action>,
55    name: String,
56}
57
/// A single step of a mock's script.
#[derive(Clone, Debug)]
enum Action {
    /// Bytes handed out to `poll_read`; drained as they are consumed.
    Read(Vec<u8>),
    /// Bytes `poll_write` must receive; drained as they are matched.
    Write(Vec<u8>),
    /// Reads and writes report `WouldBlock` until this much time has passed.
    Wait(Duration),
    // The errors are stored behind an `Arc` because `io::Error` is not
    // `Clone`, while `Builder` (which stores these actions) derives `Clone`.
    // The `Option` is set to `None` once the error has been handed out.
    /// Error returned by the next read.
    ReadError(Option<Arc<io::Error>>),
    /// Error returned by the next write.
    WriteError(Option<Arc<io::Error>>),
}
68
69struct Inner {
70    actions: VecDeque<Action>,
71    waiting: Option<Instant>,
72    sleep: Option<Pin<Box<Sleep>>>,
73    read_wait: Option<Waker>,
74    rx: UnboundedReceiverStream<Action>,
75    name: String,
76}
77
78impl Builder {
79    /// Return a new, empty `Builder`.
80    pub fn new() -> Self {
81        Self::default()
82    }
83
84    /// Sequence a `read` operation.
85    ///
86    /// The next operation in the mock's script will be to expect a `read` call
87    /// and return `buf`.
88    pub fn read(&mut self, buf: &[u8]) -> &mut Self {
89        self.actions.push_back(Action::Read(buf.into()));
90        self
91    }
92
93    /// Sequence a `read` operation that produces an error.
94    ///
95    /// The next operation in the mock's script will be to expect a `read` call
96    /// and return `error`.
97    pub fn read_error(&mut self, error: io::Error) -> &mut Self {
98        let error = Some(error.into());
99        self.actions.push_back(Action::ReadError(error));
100        self
101    }
102
103    /// Sequence a `write` operation.
104    ///
105    /// The next operation in the mock's script will be to expect a `write`
106    /// call.
107    pub fn write(&mut self, buf: &[u8]) -> &mut Self {
108        self.actions.push_back(Action::Write(buf.into()));
109        self
110    }
111
112    /// Sequence a `write` operation that produces an error.
113    ///
114    /// The next operation in the mock's script will be to expect a `write`
115    /// call that provides `error`.
116    pub fn write_error(&mut self, error: io::Error) -> &mut Self {
117        let error = Some(error.into());
118        self.actions.push_back(Action::WriteError(error));
119        self
120    }
121
122    /// Sequence a wait.
123    ///
124    /// The next operation in the mock's script will be to wait without doing so
125    /// for `duration` amount of time.
126    pub fn wait(&mut self, duration: Duration) -> &mut Self {
127        let duration = cmp::max(duration, Duration::from_millis(1));
128        self.actions.push_back(Action::Wait(duration));
129        self
130    }
131
132    /// Set name of the mock IO object to include in panic messages and debug output
133    pub fn name(&mut self, name: impl Into<String>) -> &mut Self {
134        self.name = name.into();
135        self
136    }
137
138    /// Build a `Mock` value according to the defined script.
139    pub fn build(&mut self) -> Mock {
140        let (mock, _) = self.build_with_handle();
141        mock
142    }
143
144    /// Build a `Mock` value paired with a handle
145    pub fn build_with_handle(&mut self) -> (Mock, Handle) {
146        let (inner, handle) = Inner::new(self.actions.clone(), self.name.clone());
147
148        let mock = Mock { inner };
149
150        (mock, handle)
151    }
152}
153
154impl Handle {
155    /// Sequence a `read` operation.
156    ///
157    /// The next operation in the mock's script will be to expect a `read` call
158    /// and return `buf`.
159    pub fn read(&mut self, buf: &[u8]) -> &mut Self {
160        self.tx.send(Action::Read(buf.into())).unwrap();
161        self
162    }
163
164    /// Sequence a `read` operation error.
165    ///
166    /// The next operation in the mock's script will be to expect a `read` call
167    /// and return `error`.
168    pub fn read_error(&mut self, error: io::Error) -> &mut Self {
169        let error = Some(error.into());
170        self.tx.send(Action::ReadError(error)).unwrap();
171        self
172    }
173
174    /// Sequence a `write` operation.
175    ///
176    /// The next operation in the mock's script will be to expect a `write`
177    /// call.
178    pub fn write(&mut self, buf: &[u8]) -> &mut Self {
179        self.tx.send(Action::Write(buf.into())).unwrap();
180        self
181    }
182
183    /// Sequence a `write` operation error.
184    ///
185    /// The next operation in the mock's script will be to expect a `write`
186    /// call error.
187    pub fn write_error(&mut self, error: io::Error) -> &mut Self {
188        let error = Some(error.into());
189        self.tx.send(Action::WriteError(error)).unwrap();
190        self
191    }
192}
193
194impl Inner {
195    fn new(actions: VecDeque<Action>, name: String) -> (Inner, Handle) {
196        let (tx, rx) = mpsc::unbounded_channel();
197
198        let rx = UnboundedReceiverStream::new(rx);
199
200        let inner = Inner {
201            actions,
202            sleep: None,
203            read_wait: None,
204            rx,
205            waiting: None,
206            name,
207        };
208
209        let handle = Handle { tx };
210
211        (inner, handle)
212    }
213
214    fn poll_action(&mut self, cx: &mut task::Context<'_>) -> Poll<Option<Action>> {
215        Pin::new(&mut self.rx).poll_next(cx)
216    }
217
218    fn read(&mut self, dst: &mut ReadBuf<'_>) -> io::Result<()> {
219        match self.action() {
220            Some(&mut Action::Read(ref mut data)) => {
221                // Figure out how much to copy
222                let n = cmp::min(dst.remaining(), data.len());
223
224                // Copy the data into the `dst` slice
225                dst.put_slice(&data[..n]);
226
227                // Drain the data from the source
228                data.drain(..n);
229
230                Ok(())
231            }
232            Some(&mut Action::ReadError(ref mut err)) => {
233                // As the
234                let err = err.take().expect("Should have been removed from actions.");
235                let err = Arc::try_unwrap(err).expect("There are no other references.");
236                Err(err)
237            }
238            Some(_) => {
239                // Either waiting or expecting a write
240                Err(io::ErrorKind::WouldBlock.into())
241            }
242            None => Ok(()),
243        }
244    }
245
246    fn write(&mut self, mut src: &[u8]) -> io::Result<usize> {
247        let mut ret = 0;
248
249        if self.actions.is_empty() {
250            return Err(io::ErrorKind::BrokenPipe.into());
251        }
252
253        if let Some(&mut Action::Wait(..)) = self.action() {
254            return Err(io::ErrorKind::WouldBlock.into());
255        }
256
257        if let Some(&mut Action::WriteError(ref mut err)) = self.action() {
258            let err = err.take().expect("Should have been removed from actions.");
259            let err = Arc::try_unwrap(err).expect("There are no other references.");
260            return Err(err);
261        }
262
263        for i in 0..self.actions.len() {
264            match self.actions[i] {
265                Action::Write(ref mut expect) => {
266                    let n = cmp::min(src.len(), expect.len());
267
268                    assert_eq!(&src[..n], &expect[..n], "name={} i={}", self.name, i);
269
270                    // Drop data that was matched
271                    expect.drain(..n);
272                    src = &src[n..];
273
274                    ret += n;
275
276                    if src.is_empty() {
277                        return Ok(ret);
278                    }
279                }
280                Action::Wait(..) | Action::WriteError(..) => {
281                    break;
282                }
283                _ => {}
284            }
285
286            // TODO: remove write
287        }
288
289        Ok(ret)
290    }
291
292    fn remaining_wait(&mut self) -> Option<Duration> {
293        match self.action() {
294            Some(&mut Action::Wait(dur)) => Some(dur),
295            _ => None,
296        }
297    }
298
299    fn action(&mut self) -> Option<&mut Action> {
300        loop {
301            if self.actions.is_empty() {
302                return None;
303            }
304
305            match self.actions[0] {
306                Action::Read(ref mut data) => {
307                    if !data.is_empty() {
308                        break;
309                    }
310                }
311                Action::Write(ref mut data) => {
312                    if !data.is_empty() {
313                        break;
314                    }
315                }
316                Action::Wait(ref mut dur) => {
317                    if let Some(until) = self.waiting {
318                        let now = Instant::now();
319
320                        if now < until {
321                            break;
322                        } else {
323                            self.waiting = None;
324                        }
325                    } else {
326                        self.waiting = Some(Instant::now() + *dur);
327                        break;
328                    }
329                }
330                Action::ReadError(ref mut error) | Action::WriteError(ref mut error) => {
331                    if error.is_some() {
332                        break;
333                    }
334                }
335            }
336
337            let _action = self.actions.pop_front();
338        }
339
340        self.actions.front_mut()
341    }
342}
343
// ===== impl Mock =====
345
346impl Mock {
347    fn maybe_wakeup_reader(&mut self) {
348        match self.inner.action() {
349            Some(&mut Action::Read(_)) | Some(&mut Action::ReadError(_)) | None => {
350                if let Some(waker) = self.inner.read_wait.take() {
351                    waker.wake();
352                }
353            }
354            _ => {}
355        }
356    }
357}
358
359impl AsyncRead for Mock {
360    fn poll_read(
361        mut self: Pin<&mut Self>,
362        cx: &mut task::Context<'_>,
363        buf: &mut ReadBuf<'_>,
364    ) -> Poll<io::Result<()>> {
365        loop {
366            if let Some(ref mut sleep) = self.inner.sleep {
367                ready!(Pin::new(sleep).poll(cx));
368            }
369
370            // If a sleep is set, it has already fired
371            self.inner.sleep = None;
372
373            // Capture 'filled' to monitor if it changed
374            let filled = buf.filled().len();
375
376            match self.inner.read(buf) {
377                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
378                    if let Some(rem) = self.inner.remaining_wait() {
379                        let until = Instant::now() + rem;
380                        self.inner.sleep = Some(Box::pin(time::sleep_until(until)));
381                    } else {
382                        self.inner.read_wait = Some(cx.waker().clone());
383                        return Poll::Pending;
384                    }
385                }
386                Ok(()) => {
387                    if buf.filled().len() == filled {
388                        match ready!(self.inner.poll_action(cx)) {
389                            Some(action) => {
390                                self.inner.actions.push_back(action);
391                                continue;
392                            }
393                            None => {
394                                return Poll::Ready(Ok(()));
395                            }
396                        }
397                    } else {
398                        return Poll::Ready(Ok(()));
399                    }
400                }
401                Err(e) => return Poll::Ready(Err(e)),
402            }
403        }
404    }
405}
406
407impl AsyncWrite for Mock {
408    fn poll_write(
409        mut self: Pin<&mut Self>,
410        cx: &mut task::Context<'_>,
411        buf: &[u8],
412    ) -> Poll<io::Result<usize>> {
413        loop {
414            if let Some(ref mut sleep) = self.inner.sleep {
415                ready!(Pin::new(sleep).poll(cx));
416            }
417
418            // If a sleep is set, it has already fired
419            self.inner.sleep = None;
420
421            if self.inner.actions.is_empty() {
422                match self.inner.poll_action(cx) {
423                    Poll::Pending => {
424                        // do not propagate pending
425                    }
426                    Poll::Ready(Some(action)) => {
427                        self.inner.actions.push_back(action);
428                    }
429                    Poll::Ready(None) => {
430                        panic!("unexpected write {}", self.pmsg());
431                    }
432                }
433            }
434
435            match self.inner.write(buf) {
436                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
437                    if let Some(rem) = self.inner.remaining_wait() {
438                        let until = Instant::now() + rem;
439                        self.inner.sleep = Some(Box::pin(time::sleep_until(until)));
440                    } else {
441                        // A race condition (TOCTOU) can occur if the
442                        // timer expires between the `write()` call
443                        // and `remaining_wait()` due to preemption or other
444                        // delays. In this case, the `Wait` action is already popped by
445                        // `action()`, so we continue to the next one.
446                        //
447                        // Consider the following sequence:
448                        //
449                        // poll_write      Inner          action()
450                        //    |--write()--->|               |
451                        //    |             |--action()---->| (returns Wait)
452                        //    |<-WouldBlk---|               |
453                        //    |             |               |
454                        //    |      <--- TIMEOUT! --->     |
455                        //    |  (due to preemption, etc.)  |
456                        //    |             |               |
457                        //    |-rem_wait()->|               |
458                        //    |             |--action()---->| (time's up, pop Wait)
459                        //    |<--None------|               |
460                        //    |             |               |
461                        //    |---continue->| (process next action)
462                        //
463                        // See <https://github.com/tokio-rs/tokio/issues/7881>.
464                        continue;
465                    }
466                }
467                Ok(0) => {
468                    // TODO: Is this correct?
469                    if !self.inner.actions.is_empty() {
470                        return Poll::Pending;
471                    }
472
473                    // TODO: Extract
474                    match ready!(self.inner.poll_action(cx)) {
475                        Some(action) => {
476                            self.inner.actions.push_back(action);
477                            continue;
478                        }
479                        None => {
480                            panic!("unexpected write {}", self.pmsg());
481                        }
482                    }
483                }
484                ret => {
485                    self.maybe_wakeup_reader();
486                    return Poll::Ready(ret);
487                }
488            }
489        }
490    }
491
492    fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
493        Poll::Ready(Ok(()))
494    }
495
496    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
497        Poll::Ready(Ok(()))
498    }
499}
500
501/// Ensures that Mock isn't dropped with data "inside".
502impl Drop for Mock {
503    fn drop(&mut self) {
504        // Avoid double panicking, since makes debugging much harder.
505        if std::thread::panicking() {
506            return;
507        }
508
509        self.inner.actions.iter().for_each(|a| match a {
510            Action::Read(data) => assert!(
511                data.is_empty(),
512                "There is still data left to read. {}",
513                self.pmsg()
514            ),
515            Action::Write(data) => assert!(
516                data.is_empty(),
517                "There is still data left to write. {}",
518                self.pmsg()
519            ),
520            _ => (),
521        });
522    }
523}
524/*
525/// Returns `true` if called from the context of a futures-rs Task
526fn is_task_ctx() -> bool {
527    use std::panic;
528
529    // Save the existing panic hook
530    let h = panic::take_hook();
531
532    // Install a new one that does nothing
533    panic::set_hook(Box::new(|_| {}));
534
535    // Attempt to call the fn
536    let r = panic::catch_unwind(|| task::current()).is_ok();
537
538    // Re-install the old one
539    panic::set_hook(h);
540
541    // Return the result
542    r
543}
544*/
545
546impl fmt::Debug for Inner {
547    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
548        if self.name.is_empty() {
549            write!(f, "Inner {{...}}")
550        } else {
551            write!(f, "Inner {{name={}, ...}}", self.name)
552        }
553    }
554}
555
556struct PanicMsgSnippet<'a>(&'a Inner);
557
558impl<'a> fmt::Display for PanicMsgSnippet<'a> {
559    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
560        if self.0.name.is_empty() {
561            write!(f, "({} actions remain)", self.0.actions.len())
562        } else {
563            write!(
564                f,
565                "(name {}, {} actions remain)",
566                self.0.name,
567                self.0.actions.len()
568            )
569        }
570    }
571}
572
573impl Mock {
574    fn pmsg(&self) -> PanicMsgSnippet<'_> {
575        PanicMsgSnippet(&self.inner)
576    }
577}