Skip to main content

awc/
builder.rs

1use std::{fmt, net::IpAddr, rc::Rc, time::Duration};
2
3use actix_http::{
4    error::HttpError,
5    header::{self, HeaderMap, HeaderName, TryIntoHeaderPair},
6    Uri,
7};
8use actix_rt::net::{ActixStream, TcpStream};
9use actix_service::{boxed, Service};
10use base64::prelude::*;
11
12use crate::{
13    client::{
14        ClientConfig, ConnectInfo, Connector, ConnectorService, TcpConnectError, TcpConnection,
15    },
16    connect::DefaultConnector,
17    error::SendRequestError,
18    middleware::{NestTransform, Redirect, Transform},
19    Client, ConnectRequest, ConnectResponse,
20};
21
22/// An HTTP Client builder
23///
24/// This type can be used to construct an instance of `Client` through a
25/// builder-like pattern.
26pub struct ClientBuilder<S = (), M = ()> {
27    max_http_version: Option<http::Version>,
28    stream_window_size: Option<u32>,
29    conn_window_size: Option<u32>,
30    fundamental_headers: bool,
31    default_headers: HeaderMap,
32    timeout: Option<Duration>,
33    connector: Connector<S>,
34    middleware: M,
35    local_address: Option<IpAddr>,
36    max_redirects: u8,
37}
38
39impl ClientBuilder {
40    /// Create a new ClientBuilder with default settings
41    ///
42    /// Note: If the `rustls-0_23` feature is enabled and neither `rustls-0_23-native-roots` nor
43    /// `rustls-0_23-webpki-roots` are enabled, this ClientBuilder will build without TLS. In order
44    /// to enable TLS in this scenario, a custom `Connector` _must_ be added to the builder before
45    /// finishing construction.
46    #[allow(clippy::new_ret_no_self)]
47    pub fn new() -> ClientBuilder<
48        impl Service<
49                ConnectInfo<Uri>,
50                Response = TcpConnection<Uri, TcpStream>,
51                Error = TcpConnectError,
52            > + Clone,
53        (),
54    > {
55        ClientBuilder {
56            max_http_version: None,
57            stream_window_size: None,
58            conn_window_size: None,
59            fundamental_headers: true,
60            default_headers: HeaderMap::new(),
61            timeout: Some(Duration::from_secs(5)),
62            connector: Connector::new(),
63            middleware: (),
64            local_address: None,
65            max_redirects: 10,
66        }
67    }
68}
69
70impl<S, Io, M> ClientBuilder<S, M>
71where
72    S: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, Io>, Error = TcpConnectError>
73        + Clone
74        + 'static,
75    Io: ActixStream + fmt::Debug + 'static,
76{
77    /// Use custom connector service.
78    pub fn connector<S1, Io1>(self, connector: Connector<S1>) -> ClientBuilder<S1, M>
79    where
80        S1: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, Io1>, Error = TcpConnectError>
81            + Clone
82            + 'static,
83        Io1: ActixStream + fmt::Debug + 'static,
84    {
85        ClientBuilder {
86            middleware: self.middleware,
87            fundamental_headers: self.fundamental_headers,
88            default_headers: self.default_headers,
89            timeout: self.timeout,
90            local_address: self.local_address,
91            connector,
92            max_http_version: self.max_http_version,
93            stream_window_size: self.stream_window_size,
94            conn_window_size: self.conn_window_size,
95            max_redirects: self.max_redirects,
96        }
97    }
98
99    /// Set request timeout
100    ///
101    /// Request timeout is the total time before a response must be received.
102    /// Default value is 5 seconds.
103    pub fn timeout(mut self, timeout: Duration) -> Self {
104        self.timeout = Some(timeout);
105        self
106    }
107
108    /// Disable request timeout.
109    pub fn disable_timeout(mut self) -> Self {
110        self.timeout = None;
111        self
112    }
113
114    /// Set local IP Address the connector would use for establishing connection.
115    pub fn local_address(mut self, addr: IpAddr) -> Self {
116        self.local_address = Some(addr);
117        self
118    }
119
120    /// Maximum supported HTTP major version.
121    ///
122    /// Supported versions are HTTP/1.1 and HTTP/2.
123    pub fn max_http_version(mut self, val: http::Version) -> Self {
124        self.max_http_version = Some(val);
125        self
126    }
127
128    /// Do not follow redirects.
129    ///
130    /// Redirects are allowed by default.
131    pub fn disable_redirects(mut self) -> Self {
132        self.max_redirects = 0;
133        self
134    }
135
136    /// Set max number of redirects.
137    ///
138    /// Max redirects is set to 10 by default.
139    pub fn max_redirects(mut self, num: u8) -> Self {
140        self.max_redirects = num;
141        self
142    }
143
144    /// Indicates the initial window size (in octets) for
145    /// HTTP2 stream-level flow control for received data.
146    ///
147    /// The default value is 65,535 and is good for APIs, but not for big objects.
148    pub fn initial_window_size(mut self, size: u32) -> Self {
149        self.stream_window_size = Some(size);
150        self
151    }
152
153    /// Indicates the initial window size (in octets) for
154    /// HTTP2 connection-level flow control for received data.
155    ///
156    /// The default value is 65,535 and is good for APIs, but not for big objects.
157    pub fn initial_connection_window_size(mut self, size: u32) -> Self {
158        self.conn_window_size = Some(size);
159        self
160    }
161
162    /// Do not add fundamental default request headers.
163    ///
164    /// By default `Date` and `User-Agent` headers are set.
165    pub fn no_default_headers(mut self) -> Self {
166        self.fundamental_headers = false;
167        self
168    }
169
170    /// Add default header.
171    ///
172    /// Headers added by this method get added to every request unless overridden by other methods.
173    ///
174    /// # Panics
175    /// Panics if header name or value is invalid.
176    pub fn add_default_header(mut self, header: impl TryIntoHeaderPair) -> Self {
177        match header.try_into_pair() {
178            Ok((key, value)) => self.default_headers.append(key, value),
179            Err(err) => panic!("Header error: {:?}", err.into()),
180        }
181
182        self
183    }
184
185    #[doc(hidden)]
186    #[deprecated(since = "3.0.0", note = "Prefer `add_default_header((key, value))`.")]
187    pub fn header<K, V>(mut self, key: K, value: V) -> Self
188    where
189        HeaderName: TryFrom<K>,
190        <HeaderName as TryFrom<K>>::Error: fmt::Debug + Into<HttpError>,
191        V: header::TryIntoHeaderValue,
192        V::Error: fmt::Debug,
193    {
194        match HeaderName::try_from(key) {
195            Ok(key) => match value.try_into_value() {
196                Ok(value) => {
197                    self.default_headers.append(key, value);
198                }
199                Err(err) => log::error!("Header value error: {:?}", err),
200            },
201            Err(err) => log::error!("Header name error: {:?}", err),
202        }
203        self
204    }
205
206    /// Set client wide HTTP basic authorization header
207    pub fn basic_auth<N>(self, username: N, password: Option<&str>) -> Self
208    where
209        N: fmt::Display,
210    {
211        let auth = match password {
212            Some(password) => format!("{}:{}", username, password),
213            None => format!("{}:", username),
214        };
215        self.add_default_header((
216            header::AUTHORIZATION,
217            format!("Basic {}", BASE64_STANDARD.encode(auth)),
218        ))
219    }
220
221    /// Set client wide HTTP bearer authentication header
222    pub fn bearer_auth<T>(self, token: T) -> Self
223    where
224        T: fmt::Display,
225    {
226        self.add_default_header((header::AUTHORIZATION, format!("Bearer {}", token)))
227    }
228
229    /// Registers middleware, in the form of a middleware component (type), that runs during inbound
230    /// and/or outbound processing in the request life-cycle (request -> response),
231    /// modifying request/response as necessary, across all requests managed by the `Client`.
232    pub fn wrap<S1, M1>(self, mw: M1) -> ClientBuilder<S, NestTransform<M, M1, S1, ConnectRequest>>
233    where
234        M: Transform<S1, ConnectRequest>,
235        M1: Transform<M::Transform, ConnectRequest>,
236    {
237        ClientBuilder {
238            middleware: NestTransform::new(self.middleware, mw),
239            fundamental_headers: self.fundamental_headers,
240            max_http_version: self.max_http_version,
241            stream_window_size: self.stream_window_size,
242            conn_window_size: self.conn_window_size,
243            default_headers: self.default_headers,
244            timeout: self.timeout,
245            connector: self.connector,
246            local_address: self.local_address,
247            max_redirects: self.max_redirects,
248        }
249    }
250
251    /// Finish build process and create `Client` instance.
252    pub fn finish(self) -> Client
253    where
254        M: Transform<DefaultConnector<ConnectorService<S, Io>>, ConnectRequest> + 'static,
255        M::Transform: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError>,
256    {
257        let max_redirects = self.max_redirects;
258
259        if max_redirects > 0 {
260            self.wrap(Redirect::new().max_redirect_times(max_redirects))
261                ._finish()
262        } else {
263            self._finish()
264        }
265    }
266
267    fn _finish(self) -> Client
268    where
269        M: Transform<DefaultConnector<ConnectorService<S, Io>>, ConnectRequest> + 'static,
270        M::Transform: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError>,
271    {
272        let mut connector = self.connector;
273
274        if let Some(val) = self.max_http_version {
275            connector = connector.max_http_version(val);
276        };
277        if let Some(val) = self.conn_window_size {
278            connector = connector.initial_connection_window_size(val)
279        };
280        if let Some(val) = self.stream_window_size {
281            connector = connector.initial_window_size(val)
282        };
283        if let Some(val) = self.local_address {
284            connector = connector.local_address(val);
285        }
286
287        let connector = DefaultConnector::new(connector.finish());
288        let connector = boxed::rc_service(self.middleware.new_transform(connector));
289
290        Client(ClientConfig {
291            default_headers: Rc::new(self.default_headers),
292            timeout: self.timeout,
293            connector,
294        })
295    }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads the `Authorization` default header stored on `builder` as a string slice.
    fn authorization<S, M>(builder: &ClientBuilder<S, M>) -> &str {
        let value = builder.default_headers.get(header::AUTHORIZATION).unwrap();
        value.to_str().unwrap()
    }

    #[test]
    fn client_basic_auth() {
        let with_password = ClientBuilder::new().basic_auth("username", Some("password"));
        assert_eq!(authorization(&with_password), "Basic dXNlcm5hbWU6cGFzc3dvcmQ=");

        let without_password = ClientBuilder::new().basic_auth("username", None);
        assert_eq!(authorization(&without_password), "Basic dXNlcm5hbWU6");
    }

    #[test]
    fn client_bearer_auth() {
        let builder = ClientBuilder::new().bearer_auth("someS3cr3tAutht0k3n");
        assert_eq!(authorization(&builder), "Bearer someS3cr3tAutht0k3n");
    }
}