pubhubs/misc/
stream_ext.rs1use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use futures::stream::Stream;
7
8pub trait StreamExt: Stream + Sized {
10 fn until_overridden_by<Other: Stream<Item = Self::Item>>(
14 self,
15 other: Other,
16 ) -> UntilOverriddenBy<Self, Other>;
17
18 fn breaker(self) -> Breaker<Self>;
20}
21
22impl<S: Stream> StreamExt for S {
23 fn until_overridden_by<Other: Stream<Item = Self::Item>>(
24 self,
25 other: Other,
26 ) -> UntilOverriddenBy<Self, Other> {
27 UntilOverriddenBy {
28 a: self.breaker(),
29 b: other.breaker(),
30 }
31 }
32
33 fn breaker(self) -> Breaker<Self> {
34 Breaker { inner: Some(self) }
35 }
36}
37
38pin_project_lite::pin_project! {
39pub struct Breaker<S : Stream>{
41#[pin]
42inner: Option<S>
43}
44}
45
46impl<S: Stream> Stream for Breaker<S> {
47 type Item = S::Item;
48
49 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
50 let Some(s) = self.as_mut().project().inner.as_pin_mut() else {
51 return Poll::Ready(None);
52 };
53
54 let result = s.poll_next(cx);
55
56 if matches!(result, Poll::Ready(None)) {
57 self.trip();
58 }
59
60 result
61 }
62}
63
64impl<S: Stream> futures::stream::FusedStream for Breaker<S> {
65 fn is_terminated(&self) -> bool {
66 self.inner.is_none()
67 }
68}
69
70impl<S: Stream> Breaker<S> {
71 pub fn trip(self: Pin<&mut Self>) {
74 self.project().inner.set(None)
75 }
76}
77
78pin_project_lite::pin_project! {
79pub struct UntilOverriddenBy<A, B>
81where
82 A: Stream,
83 B: Stream<Item = A::Item>,
84{
85 #[pin]
86 a: Breaker<A>,
87 #[pin]
88 b: Breaker<B>,
89}
90}
91
92impl<A, B> Stream for UntilOverriddenBy<A, B>
93where
94 A: Stream,
95 B: Stream<Item = A::Item>,
96{
97 type Item = A::Item;
98
99 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
100 let this = self.project();
101
102 match this.b.poll_next(cx) {
104 Poll::Pending | Poll::Ready(None) => this.a.poll_next(cx),
105 result @ Poll::Ready(Some(..)) => {
106 this.a.trip();
107 result
108 }
109 }
110 }
111}
112
113impl<A: Stream, B: Stream<Item = A::Item>> futures::stream::FusedStream
114 for UntilOverriddenBy<A, B>
115{
116 fn is_terminated(&self) -> bool {
117 self.a.is_terminated() && self.b.is_terminated()
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt as _;

    /// Creates an unbounded channel pre-filled with `items`, returning the sender together
    /// with the receiving end as a stream.
    fn chan(
        items: &[i32],
    ) -> (
        futures::channel::mpsc::UnboundedSender<i32>,
        impl Stream<Item = i32>,
    ) {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        items
            .iter()
            .for_each(|&item| tx.unbounded_send(item).unwrap());
        (tx, rx)
    }

    /// Polls `s` for its next item exactly once.
    fn poll_once<S>(s: &mut S) -> Poll<Option<S::Item>>
    where
        S: Stream + Unpin,
    {
        tokio_test::task::spawn(s.next()).poll()
    }

    #[test]
    fn b_overrides_a() {
        let (a_tx, a) = chan(&[1, 2, 3]);
        let (b_tx, b) = chan(&[]);
        let mut s = a.until_overridden_by(b);

        // As long as `b` is silent, the items of `a` come through.
        for expected in [1, 2] {
            tokio_test::assert_ready_eq!(poll_once(&mut s), Some(expected));
        }

        b_tx.unbounded_send(10).unwrap();
        drop(b_tx);

        // The item of `b` wins over the remaining `3` of `a`; then `b` ends.
        tokio_test::assert_ready_eq!(poll_once(&mut s), Some(10));
        tokio_test::assert_ready_eq!(poll_once(&mut s), None);

        // The receiving end of `a` is gone, so sending fails.
        assert!(a_tx.unbounded_send(99).is_err());
    }

    #[test]
    fn b_ends_without_overriding() {
        let (a_tx, a) = chan(&[]);
        let (b_tx, b) = chan(&[]);
        drop(b_tx);
        let mut s = a.until_overridden_by(b);

        a_tx.unbounded_send(1).unwrap();
        a_tx.unbounded_send(2).unwrap();
        for expected in [1, 2] {
            tokio_test::assert_ready_eq!(poll_once(&mut s), Some(expected));
        }

        // `b` ended without yielding, so the receiving end of `a` is still there.
        assert!(a_tx.unbounded_send(99).is_ok());
    }

    #[test]
    fn both_pending() {
        let (_a_tx, a) = chan(&[]);
        let (_b_tx, b) = chan(&[]);
        let mut s = a.until_overridden_by(b);

        tokio_test::assert_pending!(poll_once(&mut s));
    }
}