Skip to main content

object_store/client/http/
spawn.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::client::{
19    HttpError, HttpErrorKind, HttpRequest, HttpResponse, HttpResponseBody, HttpService,
20};
21use async_trait::async_trait;
22use bytes::Bytes;
23use http::Response;
24use http_body_util::BodyExt;
25use hyper::body::{Body, Frame};
26use std::pin::Pin;
27use std::task::{Context, Poll};
28use thiserror::Error;
29use tokio::runtime::Handle;
30use tokio::task::JoinHandle;
31
32/// Spawn error
33#[derive(Debug, Error)]
34#[error("SpawnError")]
35struct SpawnError {}
36
37impl From<SpawnError> for HttpError {
38    fn from(value: SpawnError) -> Self {
39        Self::new(HttpErrorKind::Interrupted, value)
40    }
41}
42
43/// Wraps a provided [`HttpService`] and runs it on a separate tokio runtime
44///
45/// See example on [`SpawnedReqwestConnector`]
46///
47/// [`SpawnedReqwestConnector`]: crate::client::http::SpawnedReqwestConnector
48#[derive(Debug)]
49pub struct SpawnService<T: HttpService + Clone> {
50    inner: T,
51    runtime: Handle,
52}
53
54impl<T: HttpService + Clone> SpawnService<T> {
55    /// Creates a new [`SpawnService`] from the provided
56    pub fn new(inner: T, runtime: Handle) -> Self {
57        Self { inner, runtime }
58    }
59}
60
61#[async_trait]
62impl<T: HttpService + Clone> HttpService for SpawnService<T> {
63    async fn call(&self, req: HttpRequest) -> Result<HttpResponse, HttpError> {
64        let inner = self.inner.clone();
65        let (send, recv) = tokio::sync::oneshot::channel();
66
67        // We use an unbounded channel to prevent backpressure across the runtime boundary
68        // which could in turn starve the underlying IO operations
69        let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
70
71        let handle = SpawnHandle(self.runtime.spawn(async move {
72            let r = match HttpService::call(&inner, req).await {
73                Ok(resp) => resp,
74                Err(e) => {
75                    let _ = send.send(Err(e));
76                    return;
77                }
78            };
79
80            let (parts, mut body) = r.into_parts();
81            if send.send(Ok(parts)).is_err() {
82                return;
83            }
84
85            while let Some(x) = body.frame().await {
86                if sender.send(x).is_err() {
87                    return;
88                }
89            }
90        }));
91
92        let parts = recv.await.map_err(|_| SpawnError {})??;
93
94        Ok(Response::from_parts(
95            parts,
96            HttpResponseBody::new(SpawnBody {
97                stream: receiver,
98                _worker: handle,
99            }),
100        ))
101    }
102}
103
104/// A wrapper around a [`JoinHandle`] that aborts on drop
105struct SpawnHandle(JoinHandle<()>);
106impl Drop for SpawnHandle {
107    fn drop(&mut self) {
108        self.0.abort();
109    }
110}
111
112type StreamItem = Result<Frame<Bytes>, HttpError>;
113
114struct SpawnBody {
115    stream: tokio::sync::mpsc::UnboundedReceiver<StreamItem>,
116    _worker: SpawnHandle,
117}
118
119impl Body for SpawnBody {
120    type Data = Bytes;
121    type Error = HttpError;
122
123    fn poll_frame(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<StreamItem>> {
124        self.stream.poll_recv(cx)
125    }
126}
127
#[cfg(not(target_arch = "wasm32"))]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::RetryConfig;
    use crate::client::HttpClient;
    use crate::client::mock_server::MockServer;
    use crate::client::retry::RetryExt;

    /// Fetches a canned `BANANAS` response through `client`. The request is
    /// issued from a plain OS thread driven by `futures_executor` rather than
    /// by tokio.
    ///
    /// Panics if the request or the payload check fails on that thread. In that
    /// case the worker's completion sender is dropped unused, so awaiting it
    /// fails.
    async fn test_client(client: HttpClient) {
        let (done_tx, done_rx) = tokio::sync::oneshot::channel();

        let server = MockServer::new().await;
        server.push(Response::new("BANANAS".to_string()));

        let endpoint = server.url().to_string();
        let worker = std::thread::spawn(|| {
            futures_executor::block_on(async move {
                let retry_config = RetryConfig::default();
                let response = client.get(endpoint).send_retry(&retry_config).await.unwrap();
                let payload = response.into_body().bytes().await.unwrap();
                assert_eq!(payload.as_ref(), b"BANANAS");
                let _ = done_tx.send(());
            })
        });

        done_rx.await.unwrap();
        worker.join().unwrap();
    }

    /// Routing the request through [`SpawnService`] lets it succeed from the
    /// non-tokio worker thread.
    #[tokio::test]
    async fn test_spawn() {
        let service = SpawnService::new(reqwest::Client::new(), Handle::current());
        test_client(HttpClient::new(service)).await;
    }

    /// Without [`SpawnService`], the same request from the non-tokio worker
    /// thread is expected to panic. Presumably this is because the reqwest
    /// client needs a tokio runtime (TODO confirm).
    #[tokio::test]
    #[should_panic]
    async fn test_no_spawn() {
        test_client(HttpClient::new(reqwest::Client::new())).await;
    }
}