Skip to main content

pubhubs/misc/
stream_ext.rs

1//! Tools for dealing with streams
2
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use futures::stream::Stream;
7
8/// Extension trait for [`Stream`]s
9pub trait StreamExt: Stream + Sized {
10    /// Yields items from the current, first, stream until an item from the `other`,
11    /// second, stream becomes available.  At that point the first stream is dropped,
12    /// and only items from the second stream are yielded.
13    fn until_overridden_by<Other: Stream<Item = Self::Item>>(
14        self,
15        other: Other,
16    ) -> UntilOverriddenBy<Self, Other>;
17
18    /// Like [`futures::stream::Fuse`], but can be 'tripped' to cut the stream short.
19    fn breaker(self) -> Breaker<Self>;
20}
21
22impl<S: Stream> StreamExt for S {
23    fn until_overridden_by<Other: Stream<Item = Self::Item>>(
24        self,
25        other: Other,
26    ) -> UntilOverriddenBy<Self, Other> {
27        UntilOverriddenBy {
28            a: self.breaker(),
29            b: other.breaker(),
30        }
31    }
32
33    fn breaker(self) -> Breaker<Self> {
34        Breaker { inner: Some(self) }
35    }
36}
37
38pin_project_lite::pin_project! {
39/// Return type of [`StreamExt::breaker`]
40pub struct Breaker<S : Stream>{
41#[pin]
42inner: Option<S>
43}
44}
45
46impl<S: Stream> Stream for Breaker<S> {
47    type Item = S::Item;
48
49    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
50        let Some(s) = self.as_mut().project().inner.as_pin_mut() else {
51            return Poll::Ready(None);
52        };
53
54        let result = s.poll_next(cx);
55
56        if matches!(result, Poll::Ready(None)) {
57            self.trip();
58        }
59
60        result
61    }
62}
63
64impl<S: Stream> futures::stream::FusedStream for Breaker<S> {
65    fn is_terminated(&self) -> bool {
66        self.inner.is_none()
67    }
68}
69
70impl<S: Stream> Breaker<S> {
71    /// Drop the underlying stream (if it was not already);
72    /// [`Stream::poll_next`] will return `Poll::Ready(None)` from this point onwards.
73    pub fn trip(self: Pin<&mut Self>) {
74        self.project().inner.set(None)
75    }
76}
77
78pin_project_lite::pin_project! {
79/// Return type of [`StreamExt::until_overridden_by`].
80pub struct UntilOverriddenBy<A, B>
81where
82    A: Stream,
83    B: Stream<Item = A::Item>,
84{
85    #[pin]
86    a: Breaker<A>,
87    #[pin]
88    b: Breaker<B>,
89}
90}
91
92impl<A, B> Stream for UntilOverriddenBy<A, B>
93where
94    A: Stream,
95    B: Stream<Item = A::Item>,
96{
97    type Item = A::Item;
98
99    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
100        let this = self.project();
101
102        // b.poll_next is cheap when b is terminated
103        match this.b.poll_next(cx) {
104            Poll::Pending | Poll::Ready(None) => this.a.poll_next(cx),
105            result @ Poll::Ready(Some(..)) => {
106                this.a.trip();
107                result
108            }
109        }
110    }
111}
112
113impl<A: Stream, B: Stream<Item = A::Item>> futures::stream::FusedStream
114    for UntilOverriddenBy<A, B>
115{
116    fn is_terminated(&self) -> bool {
117        self.a.is_terminated() && self.b.is_terminated()
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use futures::channel::mpsc::{self, UnboundedSender};
    use futures::StreamExt as _;

    /// Creates an unbounded channel that already has `items` queued, and returns the
    /// sender together with the receiving end as a stream.
    fn chan(items: &[i32]) -> (UnboundedSender<i32>, impl Stream<Item = i32>) {
        let (tx, rx) = mpsc::unbounded();
        items
            .iter()
            .for_each(|&item| tx.unbounded_send(item).unwrap());
        (tx, rx)
    }

    /// Polls `stream` for its next item exactly once, using a fresh test task.
    fn poll_next_once<S>(stream: &mut S) -> std::task::Poll<Option<S::Item>>
    where
        S: Stream + Unpin,
    {
        tokio_test::task::spawn(stream.next()).poll()
    }

    /// B overrides A after A has yielded some items; what remains in A is dropped
    #[test]
    fn b_overrides_a() {
        // B starts out pending; the 3 in A will never be seen
        let (a_tx, a) = chan(&[1, 2, 3]);
        let (b_tx, b) = chan(&[]);
        let mut s = a.until_overridden_by(b);

        tokio_test::assert_ready_eq!(poll_next_once(&mut s), Some(1));
        tokio_test::assert_ready_eq!(poll_next_once(&mut s), Some(2));

        b_tx.unbounded_send(10).unwrap();
        drop(b_tx);

        tokio_test::assert_ready_eq!(poll_next_once(&mut s), Some(10));
        tokio_test::assert_ready_eq!(poll_next_once(&mut s), None);

        // Sending fails because A (including its buffered 3) was dropped
        assert!(a_tx.unbounded_send(99).is_err());
    }

    /// B ends without ever overriding; A is kept and continues to yield
    #[test]
    fn b_ends_without_overriding() {
        let (a_tx, a) = chan(&[]);
        let (b_tx, b) = chan(&[]);
        drop(b_tx); // B ends right away
        let mut s = a.until_overridden_by(b);

        a_tx.unbounded_send(1).unwrap();
        a_tx.unbounded_send(2).unwrap();
        tokio_test::assert_ready_eq!(poll_next_once(&mut s), Some(1));
        tokio_test::assert_ready_eq!(poll_next_once(&mut s), Some(2));

        // Sending still succeeds because A was not dropped
        assert!(a_tx.unbounded_send(99).is_ok());
    }

    /// Both streams pending: the combined stream is pending too
    #[test]
    fn both_pending() {
        let (_a_tx, a) = chan(&[]);
        let (_b_tx, b) = chan(&[]);
        let mut s = a.until_overridden_by(b);

        tokio_test::assert_pending!(poll_next_once(&mut s));
    }
}