1use std::{fmt, net::IpAddr, rc::Rc, time::Duration};
2
3use actix_http::{
4 error::HttpError,
5 header::{self, HeaderMap, HeaderName, TryIntoHeaderPair},
6 Uri,
7};
8use actix_rt::net::{ActixStream, TcpStream};
9use actix_service::{boxed, Service};
10use base64::prelude::*;
11
12use crate::{
13 client::{
14 ClientConfig, ConnectInfo, Connector, ConnectorService, TcpConnectError, TcpConnection,
15 },
16 connect::DefaultConnector,
17 error::SendRequestError,
18 middleware::{NestTransform, Redirect, Transform},
19 Client, ConnectRequest, ConnectResponse,
20};
21
22pub struct ClientBuilder<S = (), M = ()> {
27 max_http_version: Option<http::Version>,
28 stream_window_size: Option<u32>,
29 conn_window_size: Option<u32>,
30 fundamental_headers: bool,
31 default_headers: HeaderMap,
32 timeout: Option<Duration>,
33 connector: Connector<S>,
34 middleware: M,
35 local_address: Option<IpAddr>,
36 max_redirects: u8,
37}
38
39impl ClientBuilder {
40 #[allow(clippy::new_ret_no_self)]
47 pub fn new() -> ClientBuilder<
48 impl Service<
49 ConnectInfo<Uri>,
50 Response = TcpConnection<Uri, TcpStream>,
51 Error = TcpConnectError,
52 > + Clone,
53 (),
54 > {
55 ClientBuilder {
56 max_http_version: None,
57 stream_window_size: None,
58 conn_window_size: None,
59 fundamental_headers: true,
60 default_headers: HeaderMap::new(),
61 timeout: Some(Duration::from_secs(5)),
62 connector: Connector::new(),
63 middleware: (),
64 local_address: None,
65 max_redirects: 10,
66 }
67 }
68}
69
70impl<S, Io, M> ClientBuilder<S, M>
71where
72 S: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, Io>, Error = TcpConnectError>
73 + Clone
74 + 'static,
75 Io: ActixStream + fmt::Debug + 'static,
76{
77 pub fn connector<S1, Io1>(self, connector: Connector<S1>) -> ClientBuilder<S1, M>
79 where
80 S1: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, Io1>, Error = TcpConnectError>
81 + Clone
82 + 'static,
83 Io1: ActixStream + fmt::Debug + 'static,
84 {
85 ClientBuilder {
86 middleware: self.middleware,
87 fundamental_headers: self.fundamental_headers,
88 default_headers: self.default_headers,
89 timeout: self.timeout,
90 local_address: self.local_address,
91 connector,
92 max_http_version: self.max_http_version,
93 stream_window_size: self.stream_window_size,
94 conn_window_size: self.conn_window_size,
95 max_redirects: self.max_redirects,
96 }
97 }
98
99 pub fn timeout(mut self, timeout: Duration) -> Self {
104 self.timeout = Some(timeout);
105 self
106 }
107
108 pub fn disable_timeout(mut self) -> Self {
110 self.timeout = None;
111 self
112 }
113
114 pub fn local_address(mut self, addr: IpAddr) -> Self {
116 self.local_address = Some(addr);
117 self
118 }
119
120 pub fn max_http_version(mut self, val: http::Version) -> Self {
124 self.max_http_version = Some(val);
125 self
126 }
127
128 pub fn disable_redirects(mut self) -> Self {
132 self.max_redirects = 0;
133 self
134 }
135
136 pub fn max_redirects(mut self, num: u8) -> Self {
140 self.max_redirects = num;
141 self
142 }
143
144 pub fn initial_window_size(mut self, size: u32) -> Self {
149 self.stream_window_size = Some(size);
150 self
151 }
152
153 pub fn initial_connection_window_size(mut self, size: u32) -> Self {
158 self.conn_window_size = Some(size);
159 self
160 }
161
162 pub fn no_default_headers(mut self) -> Self {
166 self.fundamental_headers = false;
167 self
168 }
169
170 pub fn add_default_header(mut self, header: impl TryIntoHeaderPair) -> Self {
177 match header.try_into_pair() {
178 Ok((key, value)) => self.default_headers.append(key, value),
179 Err(err) => panic!("Header error: {:?}", err.into()),
180 }
181
182 self
183 }
184
185 #[doc(hidden)]
186 #[deprecated(since = "3.0.0", note = "Prefer `add_default_header((key, value))`.")]
187 pub fn header<K, V>(mut self, key: K, value: V) -> Self
188 where
189 HeaderName: TryFrom<K>,
190 <HeaderName as TryFrom<K>>::Error: fmt::Debug + Into<HttpError>,
191 V: header::TryIntoHeaderValue,
192 V::Error: fmt::Debug,
193 {
194 match HeaderName::try_from(key) {
195 Ok(key) => match value.try_into_value() {
196 Ok(value) => {
197 self.default_headers.append(key, value);
198 }
199 Err(err) => log::error!("Header value error: {:?}", err),
200 },
201 Err(err) => log::error!("Header name error: {:?}", err),
202 }
203 self
204 }
205
206 pub fn basic_auth<N>(self, username: N, password: Option<&str>) -> Self
208 where
209 N: fmt::Display,
210 {
211 let auth = match password {
212 Some(password) => format!("{}:{}", username, password),
213 None => format!("{}:", username),
214 };
215 self.add_default_header((
216 header::AUTHORIZATION,
217 format!("Basic {}", BASE64_STANDARD.encode(auth)),
218 ))
219 }
220
221 pub fn bearer_auth<T>(self, token: T) -> Self
223 where
224 T: fmt::Display,
225 {
226 self.add_default_header((header::AUTHORIZATION, format!("Bearer {}", token)))
227 }
228
229 pub fn wrap<S1, M1>(self, mw: M1) -> ClientBuilder<S, NestTransform<M, M1, S1, ConnectRequest>>
233 where
234 M: Transform<S1, ConnectRequest>,
235 M1: Transform<M::Transform, ConnectRequest>,
236 {
237 ClientBuilder {
238 middleware: NestTransform::new(self.middleware, mw),
239 fundamental_headers: self.fundamental_headers,
240 max_http_version: self.max_http_version,
241 stream_window_size: self.stream_window_size,
242 conn_window_size: self.conn_window_size,
243 default_headers: self.default_headers,
244 timeout: self.timeout,
245 connector: self.connector,
246 local_address: self.local_address,
247 max_redirects: self.max_redirects,
248 }
249 }
250
251 pub fn finish(self) -> Client
253 where
254 M: Transform<DefaultConnector<ConnectorService<S, Io>>, ConnectRequest> + 'static,
255 M::Transform: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError>,
256 {
257 let max_redirects = self.max_redirects;
258
259 if max_redirects > 0 {
260 self.wrap(Redirect::new().max_redirect_times(max_redirects))
261 ._finish()
262 } else {
263 self._finish()
264 }
265 }
266
267 fn _finish(self) -> Client
268 where
269 M: Transform<DefaultConnector<ConnectorService<S, Io>>, ConnectRequest> + 'static,
270 M::Transform: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError>,
271 {
272 let mut connector = self.connector;
273
274 if let Some(val) = self.max_http_version {
275 connector = connector.max_http_version(val);
276 };
277 if let Some(val) = self.conn_window_size {
278 connector = connector.initial_connection_window_size(val)
279 };
280 if let Some(val) = self.stream_window_size {
281 connector = connector.initial_window_size(val)
282 };
283 if let Some(val) = self.local_address {
284 connector = connector.local_address(val);
285 }
286
287 let connector = DefaultConnector::new(connector.finish());
288 let connector = boxed::rc_service(self.middleware.new_transform(connector));
289
290 Client(ClientConfig {
291 default_headers: Rc::new(self.default_headers),
292 timeout: self.timeout,
293 connector,
294 })
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads the builder's `Authorization` default header back as a string slice.
    ///
    /// Unwraps both the lookup and the `&str` view, so a missing or unreadable header
    /// fails the calling test.
    fn authorization<S, M>(builder: &ClientBuilder<S, M>) -> &str {
        builder
            .default_headers
            .get(header::AUTHORIZATION)
            .unwrap()
            .to_str()
            .unwrap()
    }

    #[test]
    fn client_basic_auth() {
        let with_password = ClientBuilder::new().basic_auth("username", Some("password"));
        assert_eq!(
            authorization(&with_password),
            "Basic dXNlcm5hbWU6cGFzc3dvcmQ="
        );

        let without_password = ClientBuilder::new().basic_auth("username", None);
        assert_eq!(authorization(&without_password), "Basic dXNlcm5hbWU6");
    }

    #[test]
    fn client_bearer_auth() {
        let builder = ClientBuilder::new().bearer_auth("someS3cr3tAutht0k3n");
        assert_eq!(authorization(&builder), "Bearer someS3cr3tAutht0k3n");
    }
}