awc/responses/
json_body.rs1use std::{
2 future::Future,
3 marker::PhantomData,
4 mem,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use actix_http::{error::PayloadError, header, HttpMessage};
10use bytes::Bytes;
11use futures_core::{ready, Stream};
12use pin_project_lite::pin_project;
13use serde::de::DeserializeOwned;
14
15use super::{read_body::ReadBody, ResponseTimeout, DEFAULT_BODY_LIMIT};
16use crate::{error::JsonPayloadError, ClientResponse};
17
pin_project! {
    /// A `Future` that reads a response body stream and deserializes it as JSON into a `T`.
    ///
    /// # Errors
    /// The `Future` implementation resolves to an error if:
    /// - the response content type is not JSON (neither a `json` subtype nor a `+json` suffix);
    /// - the `Content-Length` header announces more than the configured
    ///   [limit](JsonBody::limit);
    /// - the response timeout or the payload reader reports an error;
    /// - the collected body cannot be deserialized into `T`.
    pub struct JsonBody<S, T> {
        // Body reader; `None` when the constructor rejected the content type.
        #[pin]
        body: Option<ReadBody<S>>,
        // Value parsed from the `Content-Length` header, if any; consumed by the first poll
        // that reaches the limit check.
        length: Option<usize>,
        // Timeout moved out of the originating response (or a default one on the error path).
        timeout: ResponseTimeout,
        // Error detected at construction time, reported by the first poll.
        err: Option<JsonPayloadError>,
        // Ties the deserialization target `T` to the type without storing a value of it.
        _phantom: PhantomData<T>,
    }
}
34
35impl<S, T> JsonBody<S, T>
36where
37 S: Stream<Item = Result<Bytes, PayloadError>>,
38 T: DeserializeOwned,
39{
40 pub fn new(res: &mut ClientResponse<S>) -> Self {
42 let json = if let Ok(Some(mime)) = res.mime_type() {
44 mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON)
45 } else {
46 false
47 };
48
49 if !json {
50 return JsonBody {
51 length: None,
52 body: None,
53 timeout: ResponseTimeout::default(),
54 err: Some(JsonPayloadError::ContentType),
55 _phantom: PhantomData,
56 };
57 }
58
59 let length = res
60 .headers()
61 .get(&header::CONTENT_LENGTH)
62 .and_then(|len_hdr| len_hdr.to_str().ok())
63 .and_then(|len_str| len_str.parse::<usize>().ok());
64
65 JsonBody {
66 body: Some(ReadBody::new(res.take_payload(), DEFAULT_BODY_LIMIT)),
67 length,
68 timeout: mem::take(&mut res.timeout),
69 err: None,
70 _phantom: PhantomData,
71 }
72 }
73
74 pub fn limit(mut self, limit: usize) -> Self {
76 if let Some(ref mut fut) = self.body {
77 fut.limit = limit;
78 }
79
80 self
81 }
82}
83
84impl<S, T> Future for JsonBody<S, T>
85where
86 S: Stream<Item = Result<Bytes, PayloadError>>,
87 T: DeserializeOwned,
88{
89 type Output = Result<T, JsonPayloadError>;
90
91 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
92 let this = self.project();
93
94 if let Some(err) = this.err.take() {
95 return Poll::Ready(Err(err));
96 }
97
98 if let Some(len) = this.length.take() {
99 let body = Option::as_ref(&this.body).unwrap();
100 if len > body.limit {
101 return Poll::Ready(Err(JsonPayloadError::Payload(PayloadError::Overflow)));
102 }
103 }
104
105 this.timeout
106 .poll_timeout(cx)
107 .map_err(JsonPayloadError::Payload)?;
108
109 let body = ready!(this.body.as_pin_mut().unwrap().poll(cx))?;
110 Poll::Ready(serde_json::from_slice::<T>(&body).map_err(JsonPayloadError::from))
111 }
112}
113
#[cfg(test)]
mod tests {
    use actix_http::BoxedPayloadStream;
    use serde::{Deserialize, Serialize};
    use static_assertions::assert_impl_all;

    use super::*;
    use crate::test::TestResponse;

    assert_impl_all!(JsonBody<BoxedPayloadStream, String>: Unpin);

    /// Small deserialization target for the JSON body tests.
    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct MyObject {
        name: String,
    }

    /// Returns `true` when both errors are payload overflows or both are content-type
    /// errors; every other combination compares as unequal.
    fn json_eq(err: JsonPayloadError, other: JsonPayloadError) -> bool {
        matches!(
            (err, other),
            (JsonPayloadError::ContentType, JsonPayloadError::ContentType)
                | (
                    JsonPayloadError::Payload(PayloadError::Overflow),
                    JsonPayloadError::Payload(PayloadError::Overflow)
                )
        )
    }

    #[actix_rt::test]
    async fn read_json_body() {
        // No content type at all.
        let mut response = TestResponse::default().finish();
        let outcome = JsonBody::<_, MyObject>::new(&mut response).await;
        assert!(json_eq(
            outcome.err().unwrap(),
            JsonPayloadError::ContentType
        ));

        // A content type that is not JSON.
        let mut response = TestResponse::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/text"),
            ))
            .finish();
        let outcome = JsonBody::<_, MyObject>::new(&mut response).await;
        assert!(json_eq(
            outcome.err().unwrap(),
            JsonPayloadError::ContentType
        ));

        // JSON content type, but the announced length is above the configured limit.
        let mut response = TestResponse::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/json"),
            ))
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("10000"),
            ))
            .finish();
        let outcome = JsonBody::<_, MyObject>::new(&mut response)
            .limit(100)
            .await;
        assert!(json_eq(
            outcome.err().unwrap(),
            JsonPayloadError::Payload(PayloadError::Overflow)
        ));

        // JSON content type with a well-formed payload.
        let mut response = TestResponse::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/json"),
            ))
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("16"),
            ))
            .set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
            .finish();
        let outcome = JsonBody::<_, MyObject>::new(&mut response).await;
        let expected = MyObject {
            name: "test".to_owned(),
        };
        assert_eq!(outcome.ok().unwrap(), expected);
    }
}