// awc/responses/json_body.rs

1use std::{
2    future::Future,
3    marker::PhantomData,
4    mem,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use actix_http::{error::PayloadError, header, HttpMessage};
10use bytes::Bytes;
11use futures_core::{ready, Stream};
12use pin_project_lite::pin_project;
13use serde::de::DeserializeOwned;
14
15use super::{read_body::ReadBody, ResponseTimeout, DEFAULT_BODY_LIMIT};
16use crate::{error::JsonPayloadError, ClientResponse};
17
18pin_project! {
19    /// A `Future` that reads a body stream, parses JSON, resolving to a deserialized `T`.
20    ///
21    /// # Errors
22    /// `Future` implementation returns error if:
23    /// - content type is not `application/json`;
24    /// - content length is greater than [limit](JsonBody::limit) (default: 2 MiB).
25    pub struct JsonBody<S, T> {
26        #[pin]
27        body: Option<ReadBody<S>>,
28        length: Option<usize>,
29        timeout: ResponseTimeout,
30        err: Option<JsonPayloadError>,
31        _phantom: PhantomData<T>,
32    }
33}
34
35impl<S, T> JsonBody<S, T>
36where
37    S: Stream<Item = Result<Bytes, PayloadError>>,
38    T: DeserializeOwned,
39{
40    /// Creates a JSON body stream reader from a response by taking its payload.
41    pub fn new(res: &mut ClientResponse<S>) -> Self {
42        // check content-type
43        let json = if let Ok(Some(mime)) = res.mime_type() {
44            mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON)
45        } else {
46            false
47        };
48
49        if !json {
50            return JsonBody {
51                length: None,
52                body: None,
53                timeout: ResponseTimeout::default(),
54                err: Some(JsonPayloadError::ContentType),
55                _phantom: PhantomData,
56            };
57        }
58
59        let length = res
60            .headers()
61            .get(&header::CONTENT_LENGTH)
62            .and_then(|len_hdr| len_hdr.to_str().ok())
63            .and_then(|len_str| len_str.parse::<usize>().ok());
64
65        JsonBody {
66            body: Some(ReadBody::new(res.take_payload(), DEFAULT_BODY_LIMIT)),
67            length,
68            timeout: mem::take(&mut res.timeout),
69            err: None,
70            _phantom: PhantomData,
71        }
72    }
73
74    /// Change max size of payload. Default limit is 2 MiB.
75    pub fn limit(mut self, limit: usize) -> Self {
76        if let Some(ref mut fut) = self.body {
77            fut.limit = limit;
78        }
79
80        self
81    }
82}
83
84impl<S, T> Future for JsonBody<S, T>
85where
86    S: Stream<Item = Result<Bytes, PayloadError>>,
87    T: DeserializeOwned,
88{
89    type Output = Result<T, JsonPayloadError>;
90
91    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
92        let this = self.project();
93
94        if let Some(err) = this.err.take() {
95            return Poll::Ready(Err(err));
96        }
97
98        if let Some(len) = this.length.take() {
99            let body = Option::as_ref(&this.body).unwrap();
100            if len > body.limit {
101                return Poll::Ready(Err(JsonPayloadError::Payload(PayloadError::Overflow)));
102            }
103        }
104
105        this.timeout
106            .poll_timeout(cx)
107            .map_err(JsonPayloadError::Payload)?;
108
109        let body = ready!(this.body.as_pin_mut().unwrap().poll(cx))?;
110        Poll::Ready(serde_json::from_slice::<T>(&body).map_err(JsonPayloadError::from))
111    }
112}
113
#[cfg(test)]
mod tests {
    use actix_http::BoxedPayloadStream;
    use serde::{Deserialize, Serialize};
    use static_assertions::assert_impl_all;

    use super::*;
    use crate::test::TestResponse;

    assert_impl_all!(JsonBody<BoxedPayloadStream, String>: Unpin);

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct MyObject {
        name: String,
    }

    fn json_eq(err: JsonPayloadError, other: JsonPayloadError) -> bool {
        match err {
            JsonPayloadError::Payload(PayloadError::Overflow) => {
                matches!(other, JsonPayloadError::Payload(PayloadError::Overflow))
            }
            JsonPayloadError::ContentType => matches!(other, JsonPayloadError::ContentType),
            _ => false,
        }
    }

    #[actix_rt::test]
    async fn read_json_body() {
        let mut req = TestResponse::default().finish();
        let json = JsonBody::<_, MyObject>::new(&mut req).await;
        assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType));

        let mut req = TestResponse::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/text"),
            ))
            .finish();
        let json = JsonBody::<_, MyObject>::new(&mut req).await;
        assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType));

        let mut req = TestResponse::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/json"),
            ))
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("10000"),
            ))
            .finish();

        let json = JsonBody::<_, MyObject>::new(&mut req).limit(100).await;
        assert!(json_eq(
            json.err().unwrap(),
            JsonPayloadError::Payload(PayloadError::Overflow)
        ));

        let mut req = TestResponse::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/json"),
            ))
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("16"),
            ))
            .set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
            .finish();

        let json = JsonBody::<_, MyObject>::new(&mut req).await;
        assert_eq!(
            json.ok().unwrap(),
            MyObject {
                name: "test".to_owned()
            }
        );
    }

    /// Content types with a `+json` structured-syntax suffix are accepted as JSON.
    #[actix_rt::test]
    async fn read_json_body_with_json_suffix() {
        let mut req = TestResponse::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/vnd.api+json"),
            ))
            .set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
            .finish();

        let json = JsonBody::<_, MyObject>::new(&mut req).await;
        assert_eq!(
            json.ok().unwrap(),
            MyObject {
                name: "test".to_owned()
            }
        );
    }

    /// A JSON content type whose body does not deserialize into `T` fails, and the failure is
    /// neither a content-type error nor an overflow error.
    #[actix_rt::test]
    async fn read_json_body_invalid_json() {
        let mut req = TestResponse::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/json"),
            ))
            .set_payload(Bytes::from_static(b"{\"name\": 42"))
            .finish();

        let json = JsonBody::<_, MyObject>::new(&mut req).await;
        assert!(json.is_err());
        assert!(!matches!(
            json,
            Err(JsonPayloadError::ContentType)
                | Err(JsonPayloadError::Payload(PayloadError::Overflow))
        ));
    }
}