Skip to main content

awc/
ws.rs

1//! Websockets client
2//!
3//! Type definitions required to use [`awc::Client`](super::Client) as a WebSocket client.
4//!
5//! # Examples
6//!
7//! ```no_run
8//! use awc::{Client, ws};
9//! use futures_util::{SinkExt as _, StreamExt as _};
10//!
11//! #[actix_rt::main]
12//! async fn main() {
13//!     let (_resp, mut connection) = Client::new()
14//!         .ws("ws://echo.websocket.org")
15//!         .connect()
16//!         .await
17//!         .unwrap();
18//!
19//!     connection
20//!         .send(ws::Message::Text("Echo".into()))
21//!         .await
22//!         .unwrap();
23//!     let response = connection.next().await.unwrap().unwrap();
24//!
25//!     assert_eq!(response, ws::Frame::Text("Echo".as_bytes().into()));
26//! }
27//! ```
28
29use std::{fmt, net::SocketAddr, str};
30
31use actix_codec::Framed;
32pub use actix_http::ws::{CloseCode, CloseReason, Codec, Frame, Message};
33use actix_http::{ws, Payload, RequestHead};
34use actix_rt::time::timeout;
35use actix_service::Service as _;
36use base64::prelude::*;
37
38#[cfg(feature = "cookies")]
39use crate::cookie::{Cookie, CookieJar};
40use crate::{
41    client::ClientConfig,
42    connect::{BoxedSocket, ConnectRequest},
43    error::{HttpError, InvalidUrl, SendRequestError, WsClientError},
44    http::{
45        header::{self, HeaderName, HeaderValue, TryIntoHeaderValue, AUTHORIZATION},
46        ConnectionType, Method, StatusCode, Uri, Version,
47    },
48    ClientResponse,
49};
50
51/// WebSocket connection.
52pub struct WebsocketsRequest {
53    pub(crate) head: RequestHead,
54    err: Option<HttpError>,
55    origin: Option<HeaderValue>,
56    protocols: Option<String>,
57    addr: Option<SocketAddr>,
58    max_size: usize,
59    server_mode: bool,
60    config: ClientConfig,
61
62    #[cfg(feature = "cookies")]
63    cookies: Option<CookieJar>,
64}
65
66impl WebsocketsRequest {
67    /// Create new WebSocket connection.
68    pub(crate) fn new<U>(uri: U, config: ClientConfig) -> Self
69    where
70        Uri: TryFrom<U>,
71        <Uri as TryFrom<U>>::Error: Into<HttpError>,
72    {
73        let mut err = None;
74
75        #[allow(clippy::field_reassign_with_default)]
76        let mut head = {
77            let mut head = RequestHead::default();
78            head.method = Method::GET;
79            head.version = Version::HTTP_11;
80            head
81        };
82
83        match Uri::try_from(uri) {
84            Ok(uri) => head.uri = uri,
85            Err(error) => err = Some(error.into()),
86        }
87
88        WebsocketsRequest {
89            head,
90            err,
91            config,
92            addr: None,
93            origin: None,
94            protocols: None,
95            max_size: 65_536,
96            server_mode: false,
97            #[cfg(feature = "cookies")]
98            cookies: None,
99        }
100    }
101
102    /// Set socket address of the server.
103    ///
104    /// This address is used for connection. If address is not
105    /// provided url's host name get resolved.
106    pub fn address(mut self, addr: SocketAddr) -> Self {
107        self.addr = Some(addr);
108        self
109    }
110
111    /// Set supported WebSocket protocols
112    pub fn protocols<U, V>(mut self, protos: U) -> Self
113    where
114        U: IntoIterator<Item = V>,
115        V: AsRef<str>,
116    {
117        let mut protos = protos
118            .into_iter()
119            .fold(String::new(), |acc, s| acc + s.as_ref() + ",");
120        protos.pop();
121        self.protocols = Some(protos);
122        self
123    }
124
125    /// Set a cookie
126    #[cfg(feature = "cookies")]
127    pub fn cookie(mut self, cookie: Cookie<'_>) -> Self {
128        self.cookies
129            .get_or_insert_with(CookieJar::new)
130            .add(cookie.into_owned());
131        self
132    }
133
134    /// Set request Origin
135    pub fn origin<V, E>(mut self, origin: V) -> Self
136    where
137        HeaderValue: TryFrom<V, Error = E>,
138        HttpError: From<E>,
139    {
140        match HeaderValue::try_from(origin) {
141            Ok(value) => self.origin = Some(value),
142            Err(err) => self.err = Some(err.into()),
143        }
144        self
145    }
146
147    /// Set max frame size
148    ///
149    /// By default max size is set to 64kB
150    pub fn max_frame_size(mut self, size: usize) -> Self {
151        self.max_size = size;
152        self
153    }
154
155    /// Disable payload masking. By default ws client masks frame payload.
156    pub fn server_mode(mut self) -> Self {
157        self.server_mode = true;
158        self
159    }
160
161    /// Append a header.
162    ///
163    /// Header gets appended to existing header.
164    /// To override header use `set_header()` method.
165    pub fn header<K, V>(mut self, key: K, value: V) -> Self
166    where
167        HeaderName: TryFrom<K>,
168        <HeaderName as TryFrom<K>>::Error: Into<HttpError>,
169        V: TryIntoHeaderValue,
170    {
171        match HeaderName::try_from(key) {
172            Ok(key) => match value.try_into_value() {
173                Ok(value) => {
174                    self.head.headers.append(key, value);
175                }
176                Err(err) => self.err = Some(err.into()),
177            },
178            Err(err) => self.err = Some(err.into()),
179        }
180        self
181    }
182
183    /// Insert a header, replaces existing header.
184    pub fn set_header<K, V>(mut self, key: K, value: V) -> Self
185    where
186        HeaderName: TryFrom<K>,
187        <HeaderName as TryFrom<K>>::Error: Into<HttpError>,
188        V: TryIntoHeaderValue,
189    {
190        match HeaderName::try_from(key) {
191            Ok(key) => match value.try_into_value() {
192                Ok(value) => {
193                    self.head.headers.insert(key, value);
194                }
195                Err(err) => self.err = Some(err.into()),
196            },
197            Err(err) => self.err = Some(err.into()),
198        }
199        self
200    }
201
202    /// Insert a header only if it is not yet set.
203    pub fn set_header_if_none<K, V>(mut self, key: K, value: V) -> Self
204    where
205        HeaderName: TryFrom<K>,
206        <HeaderName as TryFrom<K>>::Error: Into<HttpError>,
207        V: TryIntoHeaderValue,
208    {
209        match HeaderName::try_from(key) {
210            Ok(key) => {
211                if !self.head.headers.contains_key(&key) {
212                    match value.try_into_value() {
213                        Ok(value) => {
214                            self.head.headers.insert(key, value);
215                        }
216                        Err(err) => self.err = Some(err.into()),
217                    }
218                }
219            }
220            Err(err) => self.err = Some(err.into()),
221        }
222        self
223    }
224
225    /// Set HTTP basic authorization header
226    pub fn basic_auth<U>(self, username: U, password: Option<&str>) -> Self
227    where
228        U: fmt::Display,
229    {
230        let auth = match password {
231            Some(password) => format!("{}:{}", username, password),
232            None => format!("{}:", username),
233        };
234        self.header(
235            AUTHORIZATION,
236            format!("Basic {}", BASE64_STANDARD.encode(auth)),
237        )
238    }
239
240    /// Set HTTP bearer authentication header
241    pub fn bearer_auth<T>(self, token: T) -> Self
242    where
243        T: fmt::Display,
244    {
245        self.header(AUTHORIZATION, format!("Bearer {}", token))
246    }
247
248    /// Complete request construction and connect to a WebSocket server.
249    pub async fn connect(
250        mut self,
251    ) -> Result<(ClientResponse, Framed<BoxedSocket, Codec>), WsClientError> {
252        if let Some(err) = self.err.take() {
253            return Err(err.into());
254        }
255
256        // validate URI
257        let uri = &self.head.uri;
258
259        if uri.host().is_none() {
260            return Err(InvalidUrl::MissingHost.into());
261        } else if uri.scheme().is_none() {
262            return Err(InvalidUrl::MissingScheme.into());
263        } else if let Some(scheme) = uri.scheme() {
264            match scheme.as_str() {
265                "http" | "ws" | "https" | "wss" => {}
266                _ => return Err(InvalidUrl::UnknownScheme.into()),
267            }
268        } else {
269            return Err(InvalidUrl::UnknownScheme.into());
270        }
271
272        if !self.head.headers.contains_key(header::HOST) {
273            let hostname = uri.host().unwrap();
274            let port = uri.port();
275
276            self.head.headers.insert(
277                header::HOST,
278                HeaderValue::from_str(&Host { hostname, port }.to_string()).unwrap(),
279            );
280        }
281
282        // set cookies
283        #[cfg(feature = "cookies")]
284        if let Some(ref mut jar) = self.cookies {
285            let cookie: String = jar
286                .delta()
287                // ensure only name=value is written to cookie header
288                .map(|c| c.stripped().encoded().to_string())
289                .collect::<Vec<_>>()
290                .join("; ");
291
292            if !cookie.is_empty() {
293                self.head
294                    .headers
295                    .insert(header::COOKIE, HeaderValue::from_str(&cookie).unwrap());
296            }
297        }
298
299        // origin
300        if let Some(origin) = self.origin.take() {
301            self.head.headers.insert(header::ORIGIN, origin);
302        }
303
304        self.head.set_connection_type(ConnectionType::Upgrade);
305
306        #[allow(clippy::declare_interior_mutable_const)]
307        const HV_WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
308        self.head.headers.insert(header::UPGRADE, HV_WEBSOCKET);
309
310        #[allow(clippy::declare_interior_mutable_const)]
311        const HV_THIRTEEN: HeaderValue = HeaderValue::from_static("13");
312        self.head
313            .headers
314            .insert(header::SEC_WEBSOCKET_VERSION, HV_THIRTEEN);
315
316        if let Some(protocols) = self.protocols.take() {
317            self.head.headers.insert(
318                header::SEC_WEBSOCKET_PROTOCOL,
319                HeaderValue::try_from(protocols.as_str()).unwrap(),
320            );
321        }
322
323        // Generate a random key for the `Sec-WebSocket-Key` header which is a base64-encoded
324        // (see RFC 4648 §4) value that, when decoded, is 16 bytes in length (RFC 6455 §1.3).
325        let sec_key = rand::random::<[u8; 16]>();
326        let key = BASE64_STANDARD.encode(sec_key);
327
328        self.head.headers.insert(
329            header::SEC_WEBSOCKET_KEY,
330            HeaderValue::try_from(key.as_str()).unwrap(),
331        );
332
333        let head = self.head;
334        let max_size = self.max_size;
335        let server_mode = self.server_mode;
336
337        let req = ConnectRequest::Tunnel(head, self.addr);
338
339        let fut = self.config.connector.call(req);
340
341        // set request timeout
342        let res = if let Some(to) = self.config.timeout {
343            timeout(to, fut)
344                .await
345                .map_err(|_| SendRequestError::Timeout)??
346        } else {
347            fut.await?
348        };
349
350        let (head, framed) = res.into_tunnel_response();
351
352        // verify response
353        if head.status != StatusCode::SWITCHING_PROTOCOLS {
354            return Err(WsClientError::InvalidResponseStatus(head.status));
355        }
356
357        // check for "UPGRADE" to WebSocket header
358        let has_hdr = if let Some(hdr) = head.headers.get(&header::UPGRADE) {
359            if let Ok(s) = hdr.to_str() {
360                s.to_ascii_lowercase().contains("websocket")
361            } else {
362                false
363            }
364        } else {
365            false
366        };
367        if !has_hdr {
368            log::trace!("Invalid upgrade header");
369            return Err(WsClientError::InvalidUpgradeHeader);
370        }
371
372        // Check for "CONNECTION" header
373        if let Some(conn) = head.headers.get(&header::CONNECTION) {
374            if let Ok(s) = conn.to_str() {
375                if !s.to_ascii_lowercase().contains("upgrade") {
376                    log::trace!("Invalid connection header: {}", s);
377                    return Err(WsClientError::InvalidConnectionHeader(conn.clone()));
378                }
379            } else {
380                log::trace!("Invalid connection header: {:?}", conn);
381                return Err(WsClientError::InvalidConnectionHeader(conn.clone()));
382            }
383        } else {
384            log::trace!("Missing connection header");
385            return Err(WsClientError::MissingConnectionHeader);
386        }
387
388        if let Some(hdr_key) = head.headers.get(&header::SEC_WEBSOCKET_ACCEPT) {
389            let encoded = ws::hash_key(key.as_ref());
390
391            if hdr_key.as_bytes() != encoded {
392                log::trace!(
393                    "Invalid challenge response: expected: {:?} received: {:?}",
394                    &encoded,
395                    key
396                );
397
398                return Err(WsClientError::InvalidChallengeResponse(
399                    encoded,
400                    hdr_key.clone(),
401                ));
402            }
403        } else {
404            log::trace!("Missing SEC-WEBSOCKET-ACCEPT header");
405            return Err(WsClientError::MissingWebSocketAcceptHeader);
406        };
407
408        // response and ws framed
409        Ok((
410            ClientResponse::new(head, Payload::None),
411            framed.into_map_codec(|_| {
412                if server_mode {
413                    ws::Codec::new().max_size(max_size)
414                } else {
415                    ws::Codec::new().max_size(max_size).client_mode()
416                }
417            }),
418        ))
419    }
420}
421
422impl fmt::Debug for WebsocketsRequest {
423    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
424        writeln!(
425            f,
426            "\nWebsocketsRequest {}:{}",
427            self.head.method, self.head.uri
428        )?;
429        writeln!(f, "  headers:")?;
430        for (key, val) in self.head.headers.iter() {
431            writeln!(f, "    {:?}: {:?}", key, val)?;
432        }
433        Ok(())
434    }
435}
436
437/// Formatter for host (hostname+port) header values.
438struct Host<'a> {
439    hostname: &'a str,
440    port: Option<http::uri::Port<&'a str>>,
441}
442
443impl fmt::Display for Host<'_> {
444    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
445        f.write_str(self.hostname)?;
446
447        if let Some(port) = &self.port {
448            f.write_str(":")?;
449            f.write_str(port.as_str())?;
450        }
451
452        Ok(())
453    }
454}
455
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Client;

    #[actix_rt::test]
    async fn test_debug() {
        let request = Client::new().ws("/").header("x-test", "111");
        let repr = format!("{:?}", request);
        assert!(repr.contains("WebsocketsRequest"));
        assert!(repr.contains("x-test"));
    }

    #[actix_rt::test]
    async fn test_header_override() {
        let req = Client::builder()
            .add_default_header((header::CONTENT_TYPE, "111"))
            .finish()
            .ws("/")
            .set_header(header::CONTENT_TYPE, "222");

        assert_eq!(
            req.head
                .headers
                .get(header::CONTENT_TYPE)
                .unwrap()
                .to_str()
                .unwrap(),
            "222"
        );
    }

    #[actix_rt::test]
    async fn basic_auth() {
        let req = Client::new()
            .ws("/")
            .basic_auth("username", Some("password"));
        assert_eq!(
            req.head
                .headers
                .get(header::AUTHORIZATION)
                .unwrap()
                .to_str()
                .unwrap(),
            "Basic dXNlcm5hbWU6cGFzc3dvcmQ="
        );

        let req = Client::new().ws("/").basic_auth("username", None);
        assert_eq!(
            req.head
                .headers
                .get(header::AUTHORIZATION)
                .unwrap()
                .to_str()
                .unwrap(),
            "Basic dXNlcm5hbWU6"
        );
    }

    #[actix_rt::test]
    async fn bearer_auth() {
        let req = Client::new().ws("/").bearer_auth("someS3cr3tAutht0k3n");
        assert_eq!(
            req.head
                .headers
                .get(header::AUTHORIZATION)
                .unwrap()
                .to_str()
                .unwrap(),
            "Bearer someS3cr3tAutht0k3n"
        );

        // Only checks that `connect` can be called on the request; the future is dropped
        // without ever being polled, which is exactly what this lint would warn about.
        #[allow(clippy::let_underscore_future)]
        let _ = req.connect();
    }

    #[actix_rt::test]
    async fn basics() {
        let req = Client::new()
            .ws("http://localhost/")
            .origin("test-origin")
            .max_frame_size(100)
            .server_mode()
            .protocols(["v1", "v2"])
            .set_header_if_none(header::CONTENT_TYPE, "json")
            .set_header_if_none(header::CONTENT_TYPE, "text");

        // `WebsocketsRequest::cookie` and `Cookie` only exist with the `cookies` feature enabled
        #[cfg(feature = "cookies")]
        let req = req.cookie(Cookie::build("cookie1", "value1").finish());

        assert_eq!(
            req.origin.as_ref().unwrap().to_str().unwrap(),
            "test-origin"
        );
        assert_eq!(req.max_size, 100);
        assert!(req.server_mode);
        assert_eq!(req.protocols, Some("v1,v2".to_string()));
        assert_eq!(
            req.head.headers.get(header::CONTENT_TYPE).unwrap(),
            header::HeaderValue::from_static("json")
        );

        let _ = req.connect().await;

        assert!(Client::new().ws("/").connect().await.is_err());
        assert!(Client::new().ws("http:///test").connect().await.is_err());
        assert!(Client::new().ws("hmm://test.com/").connect().await.is_err());
    }
}