1use std::{fmt, net, rc::Rc, time::Duration};
2
3use actix_http::{
4 body::MessageBody,
5 error::HttpError,
6 header::{self, HeaderMap, HeaderValue, TryIntoHeaderPair},
7 ConnectionType, Method, RequestHead, Uri, Version,
8};
9use base64::prelude::*;
10use bytes::Bytes;
11use futures_core::Stream;
12use serde::Serialize;
13
14#[cfg(feature = "cookies")]
15use crate::cookie::{Cookie, CookieJar};
16use crate::{
17 client::ClientConfig,
18 error::{FreezeRequestError, InvalidUrl},
19 frozen::FrozenClientRequest,
20 sender::{PrepForSendingError, RequestSender, SendClientRequest},
21 BoxError,
22};
23
/// An HTTP client request builder.
///
/// Builder methods take `self` by value and return it, so calls can be chained. Fallible setters
/// (URI, header and content-type conversions) do not fail immediately: the conversion error is
/// stored and surfaced when the request is frozen or sent.
pub struct ClientRequest {
    /// Method, URI, version and headers of the outgoing request.
    pub(crate) head: RequestHead,
    /// Error recorded by a fallible builder method; each new error overwrites the previous one.
    err: Option<HttpError>,
    /// Server socket address set via `address`; handed to the request sender.
    addr: Option<net::SocketAddr>,
    /// Whether response decompression is enabled (defaults to `true`); when set, an
    /// `Accept-Encoding` header is added before sending if none is present.
    response_decompress: bool,
    /// Per-request timeout, handed to the request sender when set.
    timeout: Option<Duration>,
    /// Client configuration this request was created from.
    config: ClientConfig,

    /// Cookies added to this request; their delta is serialized into the `Cookie` header.
    #[cfg(feature = "cookies")]
    cookies: Option<CookieJar>,
}
55
56impl ClientRequest {
57 pub(crate) fn new<U>(method: Method, uri: U, config: ClientConfig) -> Self
59 where
60 Uri: TryFrom<U>,
61 <Uri as TryFrom<U>>::Error: Into<HttpError>,
62 {
63 ClientRequest {
64 config,
65 head: RequestHead::default(),
66 err: None,
67 addr: None,
68 #[cfg(feature = "cookies")]
69 cookies: None,
70 timeout: None,
71 response_decompress: true,
72 }
73 .method(method)
74 .uri(uri)
75 }
76
77 #[inline]
79 pub fn uri<U>(mut self, uri: U) -> Self
80 where
81 Uri: TryFrom<U>,
82 <Uri as TryFrom<U>>::Error: Into<HttpError>,
83 {
84 match Uri::try_from(uri) {
85 Ok(uri) => self.head.uri = uri,
86 Err(err) => self.err = Some(err.into()),
87 }
88 self
89 }
90
91 pub fn get_uri(&self) -> &Uri {
93 &self.head.uri
94 }
95
96 pub fn address(mut self, addr: net::SocketAddr) -> Self {
101 self.addr = Some(addr);
102 self
103 }
104
105 #[inline]
107 pub fn method(mut self, method: Method) -> Self {
108 self.head.method = method;
109 self
110 }
111
112 pub fn get_method(&self) -> &Method {
114 &self.head.method
115 }
116
117 #[doc(hidden)]
121 #[inline]
122 pub fn version(mut self, version: Version) -> Self {
123 self.head.version = version;
124 self
125 }
126
127 pub fn get_version(&self) -> &Version {
129 &self.head.version
130 }
131
132 pub fn get_peer_addr(&self) -> &Option<net::SocketAddr> {
134 &self.head.peer_addr
135 }
136
137 #[inline]
139 pub fn headers(&self) -> &HeaderMap {
140 &self.head.headers
141 }
142
143 #[inline]
145 pub fn headers_mut(&mut self) -> &mut HeaderMap {
146 &mut self.head.headers
147 }
148
149 pub fn insert_header(mut self, header: impl TryIntoHeaderPair) -> Self {
151 match header.try_into_pair() {
152 Ok((key, value)) => {
153 self.head.headers.insert(key, value);
154 }
155 Err(err) => self.err = Some(err.into()),
156 };
157
158 self
159 }
160
161 pub fn insert_header_if_none(mut self, header: impl TryIntoHeaderPair) -> Self {
163 match header.try_into_pair() {
164 Ok((key, value)) => {
165 if !self.head.headers.contains_key(&key) {
166 self.head.headers.insert(key, value);
167 }
168 }
169 Err(err) => self.err = Some(err.into()),
170 };
171
172 self
173 }
174
175 pub fn append_header(mut self, header: impl TryIntoHeaderPair) -> Self {
186 match header.try_into_pair() {
187 Ok((key, value)) => self.head.headers.append(key, value),
188 Err(err) => self.err = Some(err.into()),
189 };
190
191 self
192 }
193
194 #[inline]
196 pub fn camel_case(mut self) -> Self {
197 self.head.set_camel_case_headers(true);
198 self
199 }
200
201 #[inline]
204 pub fn force_close(mut self) -> Self {
205 self.head.set_connection_type(ConnectionType::Close);
206 self
207 }
208
209 #[inline]
211 pub fn content_type<V>(mut self, value: V) -> Self
212 where
213 HeaderValue: TryFrom<V>,
214 <HeaderValue as TryFrom<V>>::Error: Into<HttpError>,
215 {
216 match HeaderValue::try_from(value) {
217 Ok(value) => {
218 self.head.headers.insert(header::CONTENT_TYPE, value);
219 }
220 Err(err) => self.err = Some(err.into()),
221 }
222 self
223 }
224
225 #[inline]
227 pub fn content_length(self, len: u64) -> Self {
228 let mut buf = itoa::Buffer::new();
229 self.insert_header((header::CONTENT_LENGTH, buf.format(len)))
230 }
231
232 pub fn basic_auth(self, username: impl fmt::Display, password: impl fmt::Display) -> Self {
236 let auth = format!("{}:{}", username, password);
237
238 self.insert_header((
239 header::AUTHORIZATION,
240 format!("Basic {}", BASE64_STANDARD.encode(auth)),
241 ))
242 }
243
244 pub fn bearer_auth(self, token: impl fmt::Display) -> Self {
246 self.insert_header((header::AUTHORIZATION, format!("Bearer {}", token)))
247 }
248
249 #[cfg(feature = "cookies")]
265 pub fn cookie(mut self, cookie: Cookie<'_>) -> Self {
266 self.cookies
267 .get_or_insert_with(CookieJar::new)
268 .add(cookie.into_owned());
269 self
270 }
271
272 pub fn no_decompress(mut self) -> Self {
274 self.response_decompress = false;
275 self
276 }
277
278 pub fn timeout(mut self, timeout: Duration) -> Self {
283 self.timeout = Some(timeout);
284 self
285 }
286
287 pub fn query<T: Serialize>(mut self, query: &T) -> Result<Self, serde_urlencoded::ser::Error> {
289 let mut parts = self.head.uri.clone().into_parts();
290
291 if let Some(path_and_query) = parts.path_and_query {
292 let query = serde_urlencoded::to_string(query)?;
293 let path = path_and_query.path();
294 parts.path_and_query = format!("{}?{}", path, query).parse().ok();
295
296 match Uri::from_parts(parts) {
297 Ok(uri) => self.head.uri = uri,
298 Err(err) => self.err = Some(err.into()),
299 }
300 }
301
302 Ok(self)
303 }
304
305 pub fn freeze(self) -> Result<FrozenClientRequest, FreezeRequestError> {
308 let slf = self.prep_for_sending()?;
309
310 let request = FrozenClientRequest {
311 head: Rc::new(slf.head),
312 addr: slf.addr,
313 response_decompress: slf.response_decompress,
314 timeout: slf.timeout,
315 config: slf.config,
316 };
317
318 Ok(request)
319 }
320
321 pub fn send_body<B>(self, body: B) -> SendClientRequest
323 where
324 B: MessageBody + 'static,
325 {
326 let slf = match self.prep_for_sending() {
327 Ok(slf) => slf,
328 Err(err) => return err.into(),
329 };
330
331 RequestSender::Owned(slf.head).send_body(
332 slf.addr,
333 slf.response_decompress,
334 slf.timeout,
335 &slf.config,
336 body,
337 )
338 }
339
340 pub fn send_json<T: Serialize>(self, value: &T) -> SendClientRequest {
342 let slf = match self.prep_for_sending() {
343 Ok(slf) => slf,
344 Err(err) => return err.into(),
345 };
346
347 RequestSender::Owned(slf.head).send_json(
348 slf.addr,
349 slf.response_decompress,
350 slf.timeout,
351 &slf.config,
352 value,
353 )
354 }
355
356 pub fn send_form<T: Serialize>(self, value: &T) -> SendClientRequest {
360 let slf = match self.prep_for_sending() {
361 Ok(slf) => slf,
362 Err(err) => return err.into(),
363 };
364
365 RequestSender::Owned(slf.head).send_form(
366 slf.addr,
367 slf.response_decompress,
368 slf.timeout,
369 &slf.config,
370 value,
371 )
372 }
373
374 pub fn send_stream<S, E>(self, stream: S) -> SendClientRequest
376 where
377 S: Stream<Item = Result<Bytes, E>> + 'static,
378 E: Into<BoxError> + 'static,
379 {
380 let slf = match self.prep_for_sending() {
381 Ok(slf) => slf,
382 Err(err) => return err.into(),
383 };
384
385 RequestSender::Owned(slf.head).send_stream(
386 slf.addr,
387 slf.response_decompress,
388 slf.timeout,
389 &slf.config,
390 stream,
391 )
392 }
393
394 pub fn send(self) -> SendClientRequest {
396 let slf = match self.prep_for_sending() {
397 Ok(slf) => slf,
398 Err(err) => return err.into(),
399 };
400
401 RequestSender::Owned(slf.head).send(
402 slf.addr,
403 slf.response_decompress,
404 slf.timeout,
405 &slf.config,
406 )
407 }
408
409 fn prep_for_sending(#[allow(unused_mut)] mut self) -> Result<Self, PrepForSendingError> {
411 if let Some(err) = self.err {
412 return Err(err.into());
413 }
414
415 let uri = &self.head.uri;
417 if uri.host().is_none() {
418 return Err(InvalidUrl::MissingHost.into());
419 } else if uri.scheme().is_none() {
420 return Err(InvalidUrl::MissingScheme.into());
421 } else if let Some(scheme) = uri.scheme() {
422 match scheme.as_str() {
423 "http" | "ws" | "https" | "wss" => {}
424 _ => return Err(InvalidUrl::UnknownScheme.into()),
425 }
426 } else {
427 return Err(InvalidUrl::UnknownScheme.into());
428 }
429
430 #[cfg(feature = "cookies")]
432 if let Some(ref mut jar) = self.cookies {
433 let cookie: String = jar
434 .delta()
435 .map(|c| c.stripped().encoded().to_string())
437 .collect::<Vec<_>>()
438 .join("; ");
439
440 if !cookie.is_empty() {
441 self.head
442 .headers
443 .insert(header::COOKIE, HeaderValue::from_str(&cookie).unwrap());
444 }
445 }
446
447 let mut slf = self;
448
449 if slf.response_decompress {
453 #[allow(clippy::vec_init_then_push)]
455 #[cfg(feature = "__compress")]
456 let accept_encoding = {
457 let mut encoding = vec![];
458
459 #[cfg(feature = "compress-brotli")]
460 {
461 encoding.push("br");
462 }
463
464 #[cfg(feature = "compress-gzip")]
465 {
466 encoding.push("gzip");
467 encoding.push("deflate");
468 }
469
470 #[cfg(feature = "compress-zstd")]
471 encoding.push("zstd");
472
473 assert!(
474 !encoding.is_empty(),
475 "encoding can not be empty unless __compress feature has been explicitly enabled"
476 );
477
478 encoding.join(", ")
479 };
480
481 #[cfg(not(feature = "__compress"))]
484 let accept_encoding = "identity";
485
486 slf = slf.insert_header_if_none((header::ACCEPT_ENCODING, accept_encoding));
487 }
488
489 Ok(slf)
490 }
491}
492
493impl fmt::Debug for ClientRequest {
494 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
495 writeln!(
496 f,
497 "\nClientRequest {:?} {} {}",
498 self.head.version, self.head.method, self.head.uri
499 )?;
500 writeln!(f, " headers:")?;
501 for (key, val) in self.head.headers.iter() {
502 writeln!(f, " {:?}: {:?}", key, val)?;
503 }
504 Ok(())
505 }
506}
507
#[cfg(test)]
mod tests {
    use std::time::SystemTime;

    use actix_http::header::HttpDate;

    use super::*;
    use crate::Client;

    #[actix_rt::test]
    async fn test_debug() {
        let request = Client::new().get("/").append_header(("x-test", "111"));
        let repr = format!("{:?}", request);
        assert!(repr.contains("ClientRequest"));
        assert!(repr.contains("x-test"));
    }

    #[actix_rt::test]
    async fn test_basics() {
        let req = Client::new()
            .put("/")
            .version(Version::HTTP_2)
            .insert_header((header::DATE, HttpDate::from(SystemTime::now())))
            .content_type("plain/text")
            .append_header((header::SERVER, "awc"));

        let req = if let Some(val) = Some("server") {
            req.append_header((header::USER_AGENT, val))
        } else {
            req
        };

        let req = if let Some(_val) = Option::<&str>::None {
            req.append_header((header::ALLOW, "1"))
        } else {
            req
        };

        let mut req = req.content_length(100);

        assert!(req.headers().contains_key(header::CONTENT_TYPE));
        assert!(req.headers().contains_key(header::DATE));
        assert!(req.headers().contains_key(header::SERVER));
        assert!(req.headers().contains_key(header::USER_AGENT));
        assert!(!req.headers().contains_key(header::ALLOW));
        assert!(!req.headers().contains_key(header::EXPECT));
        assert_eq!(req.head.version, Version::HTTP_2);

        let _ = req.headers_mut();

        #[allow(clippy::let_underscore_future)]
        let _ = req.send_body("");
    }

    #[actix_rt::test]
    async fn test_client_header() {
        let req = Client::builder()
            .add_default_header((header::CONTENT_TYPE, "111"))
            .finish()
            .get("/");

        assert_eq!(
            req.head
                .headers
                .get(header::CONTENT_TYPE)
                .unwrap()
                .to_str()
                .unwrap(),
            "111"
        );
    }

    #[actix_rt::test]
    async fn test_client_header_override() {
        let req = Client::builder()
            .add_default_header((header::CONTENT_TYPE, "111"))
            .finish()
            .get("/")
            .insert_header((header::CONTENT_TYPE, "222"));

        assert_eq!(
            req.head
                .headers
                .get(header::CONTENT_TYPE)
                .unwrap()
                .to_str()
                .unwrap(),
            "222"
        );
    }

    #[actix_rt::test]
    async fn client_basic_auth() {
        let req = Client::new().get("/").basic_auth("username", "password");
        assert_eq!(
            req.head
                .headers
                .get(header::AUTHORIZATION)
                .unwrap()
                .to_str()
                .unwrap(),
            "Basic dXNlcm5hbWU6cGFzc3dvcmQ="
        );

        let req = Client::new().get("/").basic_auth("username", "");
        assert_eq!(
            req.head
                .headers
                .get(header::AUTHORIZATION)
                .unwrap()
                .to_str()
                .unwrap(),
            "Basic dXNlcm5hbWU6"
        );
    }

    #[actix_rt::test]
    async fn client_bearer_auth() {
        let req = Client::new().get("/").bearer_auth("someS3cr3tAutht0k3n");
        assert_eq!(
            req.head
                .headers
                .get(header::AUTHORIZATION)
                .unwrap()
                .to_str()
                .unwrap(),
            "Bearer someS3cr3tAutht0k3n"
        );
    }

    #[actix_rt::test]
    async fn client_query() {
        let req = Client::new()
            .get("/")
            .query(&[("key1", "val1"), ("key2", "val2")])
            .unwrap();
        assert_eq!(req.get_uri().query().unwrap(), "key1=val1&key2=val2");
    }

    /// `query` keeps the path but replaces any query that was already present.
    #[actix_rt::test]
    async fn client_query_replaces_existing() {
        let req = Client::new()
            .get("/path?old=1")
            .query(&[("key", "val")])
            .unwrap();
        assert_eq!(req.get_uri().path(), "/path");
        assert_eq!(req.get_uri().query().unwrap(), "key=val");
    }

    /// URLs without a host, without a scheme, or with an unsupported scheme are rejected.
    #[actix_rt::test]
    async fn prep_rejects_invalid_urls() {
        // origin-form: no host
        assert!(Client::new().get("/").freeze().is_err());
        // authority-form: host but no scheme
        assert!(Client::new().get("localhost:8080").freeze().is_err());
        // unsupported scheme
        assert!(Client::new().get("ftp://localhost/").freeze().is_err());
    }

    /// All supported schemes pass validation.
    #[actix_rt::test]
    async fn prep_accepts_supported_schemes() {
        for url in [
            "http://localhost/",
            "https://localhost/",
            "ws://localhost/",
            "wss://localhost/",
        ] {
            assert!(Client::new().get(url).freeze().is_ok(), "{}", url);
        }
    }

    /// Errors recorded by builder methods are surfaced when the request is frozen.
    #[actix_rt::test]
    async fn prep_reports_deferred_builder_errors() {
        let req = Client::new()
            .get("http://localhost/")
            .insert_header(("x-test", "bad\nvalue"));
        assert!(req.freeze().is_err());
    }
}