1use std::{fmt, net::SocketAddr, str};
30
31use actix_codec::Framed;
32pub use actix_http::ws::{CloseCode, CloseReason, Codec, Frame, Message};
33use actix_http::{ws, Payload, RequestHead};
34use actix_rt::time::timeout;
35use actix_service::Service as _;
36use base64::prelude::*;
37
38#[cfg(feature = "cookies")]
39use crate::cookie::{Cookie, CookieJar};
40use crate::{
41 client::ClientConfig,
42 connect::{BoxedSocket, ConnectRequest},
43 error::{HttpError, InvalidUrl, SendRequestError, WsClientError},
44 http::{
45 header::{self, HeaderName, HeaderValue, TryIntoHeaderValue, AUTHORIZATION},
46 ConnectionType, Method, StatusCode, Uri, Version,
47 },
48 ClientResponse,
49};
50
/// Builder for a WebSocket client handshake request.
///
/// Configured through its builder methods and consumed by `connect`, which sends the
/// upgrade request and validates the server's handshake response.
pub struct WebsocketsRequest {
    /// Request line and headers of the handshake request.
    pub(crate) head: RequestHead,
    /// Most recent error recorded by a builder method; `connect` returns it before doing
    /// any I/O.
    err: Option<HttpError>,
    /// Value for the `Origin` header, inserted by `connect`.
    origin: Option<HeaderValue>,
    /// Comma-joined sub-protocol list for the `Sec-WebSocket-Protocol` header.
    protocols: Option<String>,
    /// Optional socket address handed to the connector together with the request.
    addr: Option<SocketAddr>,
    /// Maximum frame size configured on the returned codec (65,536 by default).
    max_size: usize,
    /// When `true`, the returned codec is not switched into client mode.
    server_mode: bool,
    /// Client-wide settings; `connect` uses its connector and optional timeout.
    config: ClientConfig,

    /// Cookies sent as a single `Cookie` header on the handshake request.
    #[cfg(feature = "cookies")]
    cookies: Option<CookieJar>,
}
65
66impl WebsocketsRequest {
67 pub(crate) fn new<U>(uri: U, config: ClientConfig) -> Self
69 where
70 Uri: TryFrom<U>,
71 <Uri as TryFrom<U>>::Error: Into<HttpError>,
72 {
73 let mut err = None;
74
75 #[allow(clippy::field_reassign_with_default)]
76 let mut head = {
77 let mut head = RequestHead::default();
78 head.method = Method::GET;
79 head.version = Version::HTTP_11;
80 head
81 };
82
83 match Uri::try_from(uri) {
84 Ok(uri) => head.uri = uri,
85 Err(error) => err = Some(error.into()),
86 }
87
88 WebsocketsRequest {
89 head,
90 err,
91 config,
92 addr: None,
93 origin: None,
94 protocols: None,
95 max_size: 65_536,
96 server_mode: false,
97 #[cfg(feature = "cookies")]
98 cookies: None,
99 }
100 }
101
102 pub fn address(mut self, addr: SocketAddr) -> Self {
107 self.addr = Some(addr);
108 self
109 }
110
111 pub fn protocols<U, V>(mut self, protos: U) -> Self
113 where
114 U: IntoIterator<Item = V>,
115 V: AsRef<str>,
116 {
117 let mut protos = protos
118 .into_iter()
119 .fold(String::new(), |acc, s| acc + s.as_ref() + ",");
120 protos.pop();
121 self.protocols = Some(protos);
122 self
123 }
124
125 #[cfg(feature = "cookies")]
127 pub fn cookie(mut self, cookie: Cookie<'_>) -> Self {
128 self.cookies
129 .get_or_insert_with(CookieJar::new)
130 .add(cookie.into_owned());
131 self
132 }
133
134 pub fn origin<V, E>(mut self, origin: V) -> Self
136 where
137 HeaderValue: TryFrom<V, Error = E>,
138 HttpError: From<E>,
139 {
140 match HeaderValue::try_from(origin) {
141 Ok(value) => self.origin = Some(value),
142 Err(err) => self.err = Some(err.into()),
143 }
144 self
145 }
146
147 pub fn max_frame_size(mut self, size: usize) -> Self {
151 self.max_size = size;
152 self
153 }
154
155 pub fn server_mode(mut self) -> Self {
157 self.server_mode = true;
158 self
159 }
160
161 pub fn header<K, V>(mut self, key: K, value: V) -> Self
166 where
167 HeaderName: TryFrom<K>,
168 <HeaderName as TryFrom<K>>::Error: Into<HttpError>,
169 V: TryIntoHeaderValue,
170 {
171 match HeaderName::try_from(key) {
172 Ok(key) => match value.try_into_value() {
173 Ok(value) => {
174 self.head.headers.append(key, value);
175 }
176 Err(err) => self.err = Some(err.into()),
177 },
178 Err(err) => self.err = Some(err.into()),
179 }
180 self
181 }
182
183 pub fn set_header<K, V>(mut self, key: K, value: V) -> Self
185 where
186 HeaderName: TryFrom<K>,
187 <HeaderName as TryFrom<K>>::Error: Into<HttpError>,
188 V: TryIntoHeaderValue,
189 {
190 match HeaderName::try_from(key) {
191 Ok(key) => match value.try_into_value() {
192 Ok(value) => {
193 self.head.headers.insert(key, value);
194 }
195 Err(err) => self.err = Some(err.into()),
196 },
197 Err(err) => self.err = Some(err.into()),
198 }
199 self
200 }
201
202 pub fn set_header_if_none<K, V>(mut self, key: K, value: V) -> Self
204 where
205 HeaderName: TryFrom<K>,
206 <HeaderName as TryFrom<K>>::Error: Into<HttpError>,
207 V: TryIntoHeaderValue,
208 {
209 match HeaderName::try_from(key) {
210 Ok(key) => {
211 if !self.head.headers.contains_key(&key) {
212 match value.try_into_value() {
213 Ok(value) => {
214 self.head.headers.insert(key, value);
215 }
216 Err(err) => self.err = Some(err.into()),
217 }
218 }
219 }
220 Err(err) => self.err = Some(err.into()),
221 }
222 self
223 }
224
225 pub fn basic_auth<U>(self, username: U, password: Option<&str>) -> Self
227 where
228 U: fmt::Display,
229 {
230 let auth = match password {
231 Some(password) => format!("{}:{}", username, password),
232 None => format!("{}:", username),
233 };
234 self.header(
235 AUTHORIZATION,
236 format!("Basic {}", BASE64_STANDARD.encode(auth)),
237 )
238 }
239
240 pub fn bearer_auth<T>(self, token: T) -> Self
242 where
243 T: fmt::Display,
244 {
245 self.header(AUTHORIZATION, format!("Bearer {}", token))
246 }
247
248 pub async fn connect(
250 mut self,
251 ) -> Result<(ClientResponse, Framed<BoxedSocket, Codec>), WsClientError> {
252 if let Some(err) = self.err.take() {
253 return Err(err.into());
254 }
255
256 let uri = &self.head.uri;
258
259 if uri.host().is_none() {
260 return Err(InvalidUrl::MissingHost.into());
261 } else if uri.scheme().is_none() {
262 return Err(InvalidUrl::MissingScheme.into());
263 } else if let Some(scheme) = uri.scheme() {
264 match scheme.as_str() {
265 "http" | "ws" | "https" | "wss" => {}
266 _ => return Err(InvalidUrl::UnknownScheme.into()),
267 }
268 } else {
269 return Err(InvalidUrl::UnknownScheme.into());
270 }
271
272 if !self.head.headers.contains_key(header::HOST) {
273 let hostname = uri.host().unwrap();
274 let port = uri.port();
275
276 self.head.headers.insert(
277 header::HOST,
278 HeaderValue::from_str(&Host { hostname, port }.to_string()).unwrap(),
279 );
280 }
281
282 #[cfg(feature = "cookies")]
284 if let Some(ref mut jar) = self.cookies {
285 let cookie: String = jar
286 .delta()
287 .map(|c| c.stripped().encoded().to_string())
289 .collect::<Vec<_>>()
290 .join("; ");
291
292 if !cookie.is_empty() {
293 self.head
294 .headers
295 .insert(header::COOKIE, HeaderValue::from_str(&cookie).unwrap());
296 }
297 }
298
299 if let Some(origin) = self.origin.take() {
301 self.head.headers.insert(header::ORIGIN, origin);
302 }
303
304 self.head.set_connection_type(ConnectionType::Upgrade);
305
306 #[allow(clippy::declare_interior_mutable_const)]
307 const HV_WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
308 self.head.headers.insert(header::UPGRADE, HV_WEBSOCKET);
309
310 #[allow(clippy::declare_interior_mutable_const)]
311 const HV_THIRTEEN: HeaderValue = HeaderValue::from_static("13");
312 self.head
313 .headers
314 .insert(header::SEC_WEBSOCKET_VERSION, HV_THIRTEEN);
315
316 if let Some(protocols) = self.protocols.take() {
317 self.head.headers.insert(
318 header::SEC_WEBSOCKET_PROTOCOL,
319 HeaderValue::try_from(protocols.as_str()).unwrap(),
320 );
321 }
322
323 let sec_key = rand::random::<[u8; 16]>();
326 let key = BASE64_STANDARD.encode(sec_key);
327
328 self.head.headers.insert(
329 header::SEC_WEBSOCKET_KEY,
330 HeaderValue::try_from(key.as_str()).unwrap(),
331 );
332
333 let head = self.head;
334 let max_size = self.max_size;
335 let server_mode = self.server_mode;
336
337 let req = ConnectRequest::Tunnel(head, self.addr);
338
339 let fut = self.config.connector.call(req);
340
341 let res = if let Some(to) = self.config.timeout {
343 timeout(to, fut)
344 .await
345 .map_err(|_| SendRequestError::Timeout)??
346 } else {
347 fut.await?
348 };
349
350 let (head, framed) = res.into_tunnel_response();
351
352 if head.status != StatusCode::SWITCHING_PROTOCOLS {
354 return Err(WsClientError::InvalidResponseStatus(head.status));
355 }
356
357 let has_hdr = if let Some(hdr) = head.headers.get(&header::UPGRADE) {
359 if let Ok(s) = hdr.to_str() {
360 s.to_ascii_lowercase().contains("websocket")
361 } else {
362 false
363 }
364 } else {
365 false
366 };
367 if !has_hdr {
368 log::trace!("Invalid upgrade header");
369 return Err(WsClientError::InvalidUpgradeHeader);
370 }
371
372 if let Some(conn) = head.headers.get(&header::CONNECTION) {
374 if let Ok(s) = conn.to_str() {
375 if !s.to_ascii_lowercase().contains("upgrade") {
376 log::trace!("Invalid connection header: {}", s);
377 return Err(WsClientError::InvalidConnectionHeader(conn.clone()));
378 }
379 } else {
380 log::trace!("Invalid connection header: {:?}", conn);
381 return Err(WsClientError::InvalidConnectionHeader(conn.clone()));
382 }
383 } else {
384 log::trace!("Missing connection header");
385 return Err(WsClientError::MissingConnectionHeader);
386 }
387
388 if let Some(hdr_key) = head.headers.get(&header::SEC_WEBSOCKET_ACCEPT) {
389 let encoded = ws::hash_key(key.as_ref());
390
391 if hdr_key.as_bytes() != encoded {
392 log::trace!(
393 "Invalid challenge response: expected: {:?} received: {:?}",
394 &encoded,
395 key
396 );
397
398 return Err(WsClientError::InvalidChallengeResponse(
399 encoded,
400 hdr_key.clone(),
401 ));
402 }
403 } else {
404 log::trace!("Missing SEC-WEBSOCKET-ACCEPT header");
405 return Err(WsClientError::MissingWebSocketAcceptHeader);
406 };
407
408 Ok((
410 ClientResponse::new(head, Payload::None),
411 framed.into_map_codec(|_| {
412 if server_mode {
413 ws::Codec::new().max_size(max_size)
414 } else {
415 ws::Codec::new().max_size(max_size).client_mode()
416 }
417 }),
418 ))
419 }
420}
421
422impl fmt::Debug for WebsocketsRequest {
423 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
424 writeln!(
425 f,
426 "\nWebsocketsRequest {}:{}",
427 self.head.method, self.head.uri
428 )?;
429 writeln!(f, " headers:")?;
430 for (key, val) in self.head.headers.iter() {
431 writeln!(f, " {:?}: {:?}", key, val)?;
432 }
433 Ok(())
434 }
435}
436
/// Borrowed host/port pair, rendered as `host[:port]` for a default `Host` header.
struct Host<'a> {
    /// Host name taken from the request URI.
    hostname: &'a str,
    /// Port from the request URI, when one was given explicitly.
    port: Option<http::uri::Port<&'a str>>,
}
442
443impl fmt::Display for Host<'_> {
444 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
445 f.write_str(self.hostname)?;
446
447 if let Some(port) = &self.port {
448 f.write_str(":")?;
449 f.write_str(port.as_str())?;
450 }
451
452 Ok(())
453 }
454}
455
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Client;

    #[actix_rt::test]
    async fn test_debug() {
        let request = Client::new().ws("/").header("x-test", "111");
        let repr = format!("{:?}", request);
        assert!(repr.contains("WebsocketsRequest"));
        assert!(repr.contains("x-test"));
    }

    #[actix_rt::test]
    async fn test_header_override() {
        let req = Client::builder()
            .add_default_header((header::CONTENT_TYPE, "111"))
            .finish()
            .ws("/")
            .set_header(header::CONTENT_TYPE, "222");

        assert_eq!(
            req.head
                .headers
                .get(header::CONTENT_TYPE)
                .unwrap()
                .to_str()
                .unwrap(),
            "222"
        );
    }

    #[actix_rt::test]
    async fn basic_auth() {
        let req = Client::new()
            .ws("/")
            .basic_auth("username", Some("password"));
        assert_eq!(
            req.head
                .headers
                .get(header::AUTHORIZATION)
                .unwrap()
                .to_str()
                .unwrap(),
            "Basic dXNlcm5hbWU6cGFzc3dvcmQ="
        );

        let req = Client::new().ws("/").basic_auth("username", None);
        assert_eq!(
            req.head
                .headers
                .get(header::AUTHORIZATION)
                .unwrap()
                .to_str()
                .unwrap(),
            "Basic dXNlcm5hbWU6"
        );
    }

    #[actix_rt::test]
    async fn bearer_auth() {
        let req = Client::new().ws("/").bearer_auth("someS3cr3tAutht0k3n");
        assert_eq!(
            req.head
                .headers
                .get(header::AUTHORIZATION)
                .unwrap()
                .to_str()
                .unwrap(),
            "Bearer someS3cr3tAutht0k3n"
        );

        // Only checks that building the (lazy) connect future compiles; it is never polled.
        #[allow(clippy::let_underscore_future)]
        let _ = req.connect();
    }

    #[actix_rt::test]
    async fn protocols_edge_cases() {
        // A single protocol gets no separator.
        let req = Client::new().ws("/").protocols(["v1"]);
        assert_eq!(req.protocols.as_deref(), Some("v1"));

        // An empty iterator yields an empty (but present) list.
        let req = Client::new().ws("/").protocols(Vec::<&str>::new());
        assert_eq!(req.protocols.as_deref(), Some(""));
    }

    #[actix_rt::test]
    async fn basics() {
        let req = Client::new()
            .ws("http://localhost/")
            .origin("test-origin")
            .max_frame_size(100)
            .server_mode()
            .protocols(["v1", "v2"])
            .set_header_if_none(header::CONTENT_TYPE, "json")
            .set_header_if_none(header::CONTENT_TYPE, "text");

        // `cookie` only exists with the "cookies" feature; gating it keeps this test
        // compiling under `--no-default-features`.
        #[cfg(feature = "cookies")]
        let req = req.cookie(Cookie::build("cookie1", "value1").finish());

        assert_eq!(
            req.origin.as_ref().unwrap().to_str().unwrap(),
            "test-origin"
        );
        assert_eq!(req.max_size, 100);
        assert!(req.server_mode);
        assert_eq!(req.protocols, Some("v1,v2".to_string()));
        assert_eq!(
            req.head.headers.get(header::CONTENT_TYPE).unwrap(),
            header::HeaderValue::from_static("json")
        );

        let _ = req.connect().await;

        assert!(Client::new().ws("/").connect().await.is_err());
        assert!(Client::new().ws("http:///test").connect().await.is_err());
        assert!(Client::new().ws("hmm://test.com/").connect().await.is_err());
    }
}
561}