tokio_util/task/abort_on_drop.rs
//! An [`AbortOnDropHandle`] is like a [`JoinHandle`], except that it
//! will abort the task as soon as it is dropped.
//!
//! Correspondingly, an [`AbortOnDrop`] is like an [`AbortHandle`] that will abort
//! the task as soon as it is dropped.
6
7use tokio::task::{AbortHandle, JoinError, JoinHandle};
8
9use std::{
10 future::Future,
11 mem::ManuallyDrop,
12 pin::Pin,
13 task::{Context, Poll},
14};
15
16/// A wrapper around a [`tokio::task::JoinHandle`],
17/// which [aborts] the task when it is dropped.
18///
19/// [aborts]: tokio::task::JoinHandle::abort
20#[must_use = "Dropping the handle aborts the task immediately"]
21pub struct AbortOnDropHandle<T>(JoinHandle<T>);
22
23impl<T> Drop for AbortOnDropHandle<T> {
24 fn drop(&mut self) {
25 self.abort()
26 }
27}
28
29impl<T> AbortOnDropHandle<T> {
30 /// Create an [`AbortOnDropHandle`] from a [`JoinHandle`].
31 pub fn new(handle: JoinHandle<T>) -> Self {
32 Self(handle)
33 }
34
35 /// Abort the task associated with this handle,
36 /// equivalent to [`JoinHandle::abort`].
37 #[inline]
38 pub fn abort(&self) {
39 self.0.abort()
40 }
41
42 /// Checks if the task associated with this handle is finished,
43 /// equivalent to [`JoinHandle::is_finished`].
44 #[inline]
45 pub fn is_finished(&self) -> bool {
46 self.0.is_finished()
47 }
48
49 /// Returns a new [`AbortHandle`] that can be used to remotely abort this task,
50 /// equivalent to [`JoinHandle::abort_handle`].
51 pub fn abort_handle(&self) -> AbortHandle {
52 self.0.abort_handle()
53 }
54
55 /// Cancels aborting on drop and returns the original [`JoinHandle`].
56 pub fn detach(self) -> JoinHandle<T> {
57 // Avoid invoking `AbortOnDropHandle`'s `Drop` impl
58 let this = ManuallyDrop::new(self);
59 // SAFETY: `&this.0` is a reference, so it is certainly initialized, and
60 // it won't be double-dropped because it's in a `ManuallyDrop`
61 unsafe { std::ptr::read(&this.0) }
62 }
63}
64
65impl<T> std::fmt::Debug for AbortOnDropHandle<T> {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 f.debug_struct("AbortOnDropHandle")
68 .field("id", &self.0.id())
69 .finish()
70 }
71}
72
73impl<T> Future for AbortOnDropHandle<T> {
74 type Output = Result<T, JoinError>;
75
76 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
77 Pin::new(&mut self.0).poll(cx)
78 }
79}
80
81impl<T> AsRef<JoinHandle<T>> for AbortOnDropHandle<T> {
82 fn as_ref(&self) -> &JoinHandle<T> {
83 &self.0
84 }
85}
86
87/// A wrapper around a [`tokio::task::AbortHandle`],
88/// which [aborts] the task when it is dropped.
89///
90/// Unlike [`AbortOnDropHandle`], [`AbortOnDrop`] cannot be `.await`ed for a result.
91///
92/// It has no generic parameter, making it suitable when you only need to keep
93/// a task handle in a struct and do not care about the output.
94///
95/// [aborts]: tokio::task::AbortHandle::abort
96#[must_use = "Dropping the handle aborts the task immediately"]
97pub struct AbortOnDrop(AbortHandle);
98
99impl Drop for AbortOnDrop {
100 fn drop(&mut self) {
101 self.abort()
102 }
103}
104
105impl AbortOnDrop {
106 /// Create an [`AbortOnDrop`] from a [`AbortHandle`].
107 pub fn new(handle: AbortHandle) -> Self {
108 Self(handle)
109 }
110
111 /// Abort the task associated with this handle,
112 /// equivalent to [`AbortHandle::abort`].
113 #[inline]
114 pub fn abort(&self) {
115 self.0.abort()
116 }
117
118 /// Checks if the task associated with this handle is finished,
119 /// equivalent to [`AbortHandle::is_finished`].
120 #[inline]
121 pub fn is_finished(&self) -> bool {
122 self.0.is_finished()
123 }
124
125 /// Cancels aborting on drop and returns the original [`AbortHandle`].
126 pub fn detach(self) -> AbortHandle {
127 // Avoid invoking `AbortOnDrop`'s `Drop` impl
128 let this = ManuallyDrop::new(self);
129 // SAFETY: `&this.0` is a reference, so it is certainly initialized, and
130 // it won't be double-dropped because it's in a `ManuallyDrop`
131 unsafe { std::ptr::read(&this.0) }
132 }
133}
134
135impl std::fmt::Debug for AbortOnDrop {
136 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
137 f.debug_struct("AbortOnDrop")
138 .field("id", &self.0.id())
139 .finish()
140 }
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use std::fmt::Debug;

    /// Marker type that deliberately has no [`Debug`] implementation.
    struct NotDebug;

    /// Type-level check: a call only compiles when `T: Debug`.
    fn is_debug<T: Debug>() {}

    #[test]
    fn assert_debug() {
        // Both wrappers must be `Debug` without requiring the task's output type
        // to be `Debug`.
        is_debug::<AbortOnDrop>();
        is_debug::<AbortOnDropHandle<NotDebug>>();
    }
}