Skip to main content

awc/middleware/
redirect.rs

1use std::{
2    future::Future,
3    net::SocketAddr,
4    pin::Pin,
5    rc::Rc,
6    task::{Context, Poll},
7};
8
9use actix_http::{header, Method, RequestHead, RequestHeadType, StatusCode, Uri};
10use actix_service::Service;
11use bytes::Bytes;
12use futures_core::ready;
13
14use super::Transform;
15use crate::{
16    any_body::AnyBody,
17    client::{InvalidUrl, SendRequestError},
18    connect::{ConnectRequest, ConnectResponse},
19    ClientResponse,
20};
21
/// Middleware that follows HTTP redirect responses (301, 302, 303, 307 and 308).
///
/// By default at most 10 redirects are followed for a single request; use
/// `Redirect::max_redirect_times` to change the limit.
#[derive(Debug, Clone, Copy)]
pub struct Redirect {
    max_redirect_times: u8,
}

impl Default for Redirect {
    fn default() -> Self {
        Self::new()
    }
}

impl Redirect {
    /// Creates a redirect middleware that follows up to 10 redirects per request.
    pub fn new() -> Self {
        Self {
            max_redirect_times: 10,
        }
    }

    /// Sets the maximum number of redirects followed for a single request.
    ///
    /// With `0`, redirect responses are returned to the caller without being followed.
    pub fn max_redirect_times(mut self, times: u8) -> Self {
        self.max_redirect_times = times;
        self
    }
}
44
45impl<S> Transform<S, ConnectRequest> for Redirect
46where
47    S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static,
48{
49    type Transform = RedirectService<S>;
50
51    fn new_transform(self, service: S) -> Self::Transform {
52        RedirectService {
53            max_redirect_times: self.max_redirect_times,
54            connector: Rc::new(service),
55        }
56    }
57}
58
/// Service produced by [`Redirect`]: forwards connect requests to the wrapped connector and
/// follows redirect responses for client requests.
pub struct RedirectService<S> {
    /// Wrapped connector; a clone of this `Rc` is handed to every client request future.
    connector: Rc<S>,
    /// Maximum number of redirects followed for one request.
    max_redirect_times: u8,
}
63
64impl<S> Service<ConnectRequest> for RedirectService<S>
65where
66    S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static,
67{
68    type Response = S::Response;
69    type Error = S::Error;
70    type Future = RedirectServiceFuture<S>;
71
72    actix_service::forward_ready!(connector);
73
74    fn call(&self, req: ConnectRequest) -> Self::Future {
75        match req {
76            ConnectRequest::Tunnel(head, addr) => {
77                let fut = self.connector.call(ConnectRequest::Tunnel(head, addr));
78                RedirectServiceFuture::Tunnel { fut }
79            }
80            ConnectRequest::Client(head, body, addr) => {
81                let connector = Rc::clone(&self.connector);
82                let max_redirect_times = self.max_redirect_times;
83
84                // backup the uri and method for reuse schema and authority.
85                let (uri, method, headers) = match head {
86                    RequestHeadType::Owned(ref head) => {
87                        (head.uri.clone(), head.method.clone(), head.headers.clone())
88                    }
89                    RequestHeadType::Rc(ref head, ..) => {
90                        (head.uri.clone(), head.method.clone(), head.headers.clone())
91                    }
92                };
93
94                let body_opt = match body {
95                    AnyBody::Bytes { ref body } => Some(body.clone()),
96                    _ => None,
97                };
98
99                let fut = connector.call(ConnectRequest::Client(head, body, addr));
100
101                RedirectServiceFuture::Client {
102                    fut,
103                    max_redirect_times,
104                    uri: Some(uri),
105                    method: Some(method),
106                    headers: Some(headers),
107                    body: body_opt,
108                    addr,
109                    connector: Some(connector),
110                }
111            }
112        }
113    }
114}
115
116pin_project_lite::pin_project! {
117    #[project = RedirectServiceProj]
118    pub enum RedirectServiceFuture<S>
119    where
120        S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError>,
121        S: 'static
122    {
123        Tunnel { #[pin] fut: S::Future },
124        Client {
125            #[pin]
126            fut: S::Future,
127            max_redirect_times: u8,
128            uri: Option<Uri>,
129            method: Option<Method>,
130            headers: Option<header::HeaderMap>,
131            body: Option<Bytes>,
132            addr: Option<SocketAddr>,
133            connector: Option<Rc<S>>,
134        }
135    }
136}
137
138impl<S> Future for RedirectServiceFuture<S>
139where
140    S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static,
141{
142    type Output = Result<ConnectResponse, SendRequestError>;
143
144    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
145        match self.as_mut().project() {
146            RedirectServiceProj::Tunnel { fut } => fut.poll(cx),
147            RedirectServiceProj::Client {
148                fut,
149                max_redirect_times,
150                uri,
151                method,
152                headers,
153                body,
154                addr,
155                connector,
156            } => match ready!(fut.poll(cx))? {
157                ConnectResponse::Client(res) => match res.head().status {
158                    StatusCode::MOVED_PERMANENTLY
159                    | StatusCode::FOUND
160                    | StatusCode::SEE_OTHER
161                    | StatusCode::TEMPORARY_REDIRECT
162                    | StatusCode::PERMANENT_REDIRECT
163                        if *max_redirect_times > 0
164                            && res.headers().contains_key(header::LOCATION) =>
165                    {
166                        let reuse_body = res.head().status == StatusCode::TEMPORARY_REDIRECT
167                            || res.head().status == StatusCode::PERMANENT_REDIRECT;
168
169                        let prev_uri = uri.take().unwrap();
170
171                        // rebuild uri from the location header value.
172                        let next_uri = build_next_uri(&res, &prev_uri)?;
173
174                        // take ownership of states that could be reused
175                        let addr = addr.take();
176                        let connector = connector.take();
177
178                        // reset method
179                        let method = if reuse_body {
180                            method.take().unwrap()
181                        } else {
182                            let method = method.take().unwrap();
183                            match method {
184                                Method::GET | Method::HEAD => method,
185                                _ => Method::GET,
186                            }
187                        };
188
189                        let mut body = body.take();
190                        let body_new = if reuse_body {
191                            // try to reuse saved body
192                            match body {
193                                Some(ref bytes) => AnyBody::Bytes {
194                                    body: bytes.clone(),
195                                },
196
197                                // body was a non-reusable type so send an empty body instead
198                                _ => AnyBody::empty(),
199                            }
200                        } else {
201                            body = None;
202                            // remove body since we're downgrading to a GET
203                            AnyBody::None
204                        };
205
206                        let mut headers = headers.take().unwrap();
207
208                        remove_sensitive_headers(&mut headers, &prev_uri, &next_uri);
209
210                        // use a new request head.
211                        let mut head = RequestHead::default();
212                        head.uri = next_uri.clone();
213                        head.method = method.clone();
214                        head.headers = headers.clone();
215
216                        let head = RequestHeadType::Owned(head);
217
218                        let mut max_redirect_times = *max_redirect_times;
219                        max_redirect_times -= 1;
220
221                        let fut = connector
222                            .as_ref()
223                            .unwrap()
224                            .call(ConnectRequest::Client(head, body_new, addr));
225
226                        self.set(RedirectServiceFuture::Client {
227                            fut,
228                            max_redirect_times,
229                            uri: Some(next_uri),
230                            method: Some(method),
231                            headers: Some(headers),
232                            body,
233                            addr,
234                            connector,
235                        });
236
237                        self.poll(cx)
238                    }
239                    _ => Poll::Ready(Ok(ConnectResponse::Client(res))),
240                },
241                _ => unreachable!("ConnectRequest::Tunnel is not handled by Redirect"),
242            },
243        }
244    }
245}
246
247fn build_next_uri(res: &ClientResponse, prev_uri: &Uri) -> Result<Uri, SendRequestError> {
248    // responses without this header are not processed
249    let location = res.headers().get(header::LOCATION).unwrap();
250
251    // try to parse the location and resolve to a full URI but fall back to default if it fails
252    let uri = Uri::try_from(location.as_bytes()).unwrap_or_else(|_| Uri::default());
253
254    let uri = if uri.scheme().is_none() || uri.authority().is_none() {
255        let builder = Uri::builder()
256            .scheme(prev_uri.scheme().cloned().unwrap())
257            .authority(prev_uri.authority().cloned().unwrap());
258
259        // scheme-relative address
260        if location.as_bytes().starts_with(b"//") {
261            let scheme = prev_uri.scheme_str().unwrap();
262            let mut full_url: Vec<u8> = scheme.as_bytes().to_vec();
263            full_url.push(b':');
264            full_url.extend(location.as_bytes());
265
266            return Uri::try_from(full_url)
267                .map_err(|_| SendRequestError::Url(InvalidUrl::MissingScheme));
268        }
269        // when scheme or authority is missing treat the location value as path and query
270        // recover error where location does not have leading slash
271        let path = if location.as_bytes().starts_with(b"/") {
272            location.as_bytes().to_owned()
273        } else {
274            [b"/", location.as_bytes()].concat()
275        };
276
277        builder
278            .path_and_query(path)
279            .build()
280            .map_err(|err| SendRequestError::Url(InvalidUrl::HttpError(err)))?
281    } else {
282        uri
283    };
284
285    Ok(uri)
286}
287
288fn remove_sensitive_headers(headers: &mut header::HeaderMap, prev_uri: &Uri, next_uri: &Uri) {
289    if next_uri.host() != prev_uri.host()
290        || next_uri.port() != prev_uri.port()
291        || next_uri.scheme() != prev_uri.scheme()
292    {
293        headers.remove(header::COOKIE);
294        headers.remove(header::AUTHORIZATION);
295        headers.remove(header::PROXY_AUTHORIZATION);
296    }
297}
298
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use actix_web::{web, App, Error, HttpRequest, HttpResponse};

    use super::*;
    use crate::{http::header::HeaderValue, ClientBuilder};

    /// A single absolute-path redirect is followed and the target's response is returned.
    #[actix_rt::test]
    async fn basic_redirect() {
        let client = ClientBuilder::new()
            .disable_redirects()
            .wrap(Redirect::new().max_redirect_times(10))
            .finish();

        let srv = actix_test::start(|| {
            App::new()
                .service(web::resource("/test").route(web::to(|| async {
                    Ok::<_, Error>(HttpResponse::BadRequest())
                })))
                .service(web::resource("/").route(web::to(|| async {
                    Ok::<_, Error>(
                        HttpResponse::Found()
                            .append_header(("location", "/test"))
                            .finish(),
                    )
                })))
        });

        let res = client.get(srv.url("/")).send().await.unwrap();

        assert_eq!(res.status().as_u16(), 400);
    }

    /// A `Location` without a leading slash (`abc/` requested from `/`) reaches `/abc/`.
    #[actix_rt::test]
    async fn redirect_relative_without_leading_slash() {
        let client = ClientBuilder::new().finish();

        let srv = actix_test::start(|| {
            App::new()
                .service(web::resource("/").route(web::to(|| async {
                    HttpResponse::Found()
                        .insert_header(("location", "abc/"))
                        .finish()
                })))
                .service(
                    web::resource("/abc/")
                        .route(web::to(|| async { HttpResponse::Accepted().finish() })),
                )
        });

        let res = client.get(srv.url("/")).send().await.unwrap();
        assert_eq!(res.status(), StatusCode::ACCEPTED);
    }

    /// A redirect status without a `Location` header is returned to the caller as-is.
    #[actix_rt::test]
    async fn redirect_without_location() {
        let client = ClientBuilder::new()
            .disable_redirects()
            .wrap(Redirect::new().max_redirect_times(10))
            .finish();

        let srv = actix_test::start(|| {
            App::new().service(web::resource("/").route(web::to(|| async {
                Ok::<_, Error>(HttpResponse::Found().finish())
            })))
        });

        let res = client.get(srv.url("/")).send().await.unwrap();
        assert_eq!(res.status(), StatusCode::FOUND);
    }

    /// With a limit of 1, the second redirect of a chain is returned instead of followed.
    #[actix_rt::test]
    async fn test_redirect_limit() {
        let client = ClientBuilder::new()
            .disable_redirects()
            .wrap(Redirect::new().max_redirect_times(1))
            .connector(crate::Connector::new())
            .finish();

        let srv = actix_test::start(|| {
            App::new()
                .service(web::resource("/").route(web::to(|| async {
                    Ok::<_, Error>(
                        HttpResponse::Found()
                            .insert_header(("location", "/test"))
                            .finish(),
                    )
                })))
                .service(web::resource("/test").route(web::to(|| async {
                    Ok::<_, Error>(
                        HttpResponse::Found()
                            .insert_header(("location", "/test2"))
                            .finish(),
                    )
                })))
                .service(web::resource("/test2").route(web::to(|| async {
                    Ok::<_, Error>(HttpResponse::BadRequest())
                })))
        });

        let res = client.get(srv.url("/")).send().await.unwrap();
        assert_eq!(res.status(), StatusCode::FOUND);
        assert_eq!(
            res.headers()
                .get(header::LOCATION)
                .unwrap()
                .to_str()
                .unwrap(),
            "/test2"
        );
    }

    /// A 307 keeps the POST method and replays the request body.
    #[actix_rt::test]
    async fn test_redirect_status_kind_307_308() {
        let srv = actix_test::start(|| {
            async fn root() -> HttpResponse {
                HttpResponse::TemporaryRedirect()
                    .append_header(("location", "/test"))
                    .finish()
            }

            async fn test(req: HttpRequest, body: Bytes) -> HttpResponse {
                if req.method() == Method::POST && !body.is_empty() {
                    HttpResponse::Ok().finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            App::new()
                .service(web::resource("/").route(web::to(root)))
                .service(web::resource("/test").route(web::to(test)))
        });

        let res = srv.post("/").send_body("Hello").await.unwrap();
        assert_eq!(res.status().as_u16(), 200);
    }

    /// A 302 turns POST into GET and drops the body, whether or not one was sent.
    #[actix_rt::test]
    async fn test_redirect_status_kind_301_302_303() {
        let srv = actix_test::start(|| {
            async fn root() -> HttpResponse {
                HttpResponse::Found()
                    .append_header(("location", "/test"))
                    .finish()
            }

            async fn test(req: HttpRequest, body: Bytes) -> HttpResponse {
                if (req.method() == Method::GET || req.method() == Method::HEAD) && body.is_empty()
                {
                    HttpResponse::Ok().finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            App::new()
                .service(web::resource("/").route(web::to(root)))
                .service(web::resource("/test").route(web::to(test)))
        });

        let res = srv.post("/").send_body("Hello").await.unwrap();
        assert_eq!(res.status().as_u16(), 200);

        let res = srv.post("/").send().await.unwrap();
        assert_eq!(res.status().as_u16(), 200);
    }

    /// Custom headers, set as client defaults or per request, survive a same-origin redirect.
    #[actix_rt::test]
    async fn test_redirect_headers() {
        let srv = actix_test::start(|| {
            async fn root(req: HttpRequest) -> HttpResponse {
                if req
                    .headers()
                    .get("custom")
                    .unwrap_or(&HeaderValue::from_str("").unwrap())
                    == "value"
                {
                    HttpResponse::Found()
                        .append_header(("location", "/test"))
                        .finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            async fn test(req: HttpRequest) -> HttpResponse {
                if req
                    .headers()
                    .get("custom")
                    .unwrap_or(&HeaderValue::from_str("").unwrap())
                    == "value"
                {
                    HttpResponse::Ok().finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            App::new()
                .service(web::resource("/").route(web::to(root)))
                .service(web::resource("/test").route(web::to(test)))
        });

        let client = ClientBuilder::new()
            .add_default_header(("custom", "value"))
            .disable_redirects()
            .finish();
        let res = client.get(srv.url("/")).send().await.unwrap();
        assert_eq!(res.status().as_u16(), 302);

        let client = ClientBuilder::new()
            .add_default_header(("custom", "value"))
            .finish();
        let res = client.get(srv.url("/")).send().await.unwrap();
        assert_eq!(res.status().as_u16(), 200);

        let client = ClientBuilder::new().finish();
        let res = client
            .get(srv.url("/"))
            .insert_header(("custom", "value"))
            .send()
            .await
            .unwrap();
        assert_eq!(res.status().as_u16(), 200);
    }

    /// `Authorization` is stripped when the redirect changes origin and kept when it does not.
    #[actix_rt::test]
    async fn test_redirect_cross_origin_headers() {
        // defining two services to have two different origins
        let srv2 = actix_test::start(|| {
            async fn root(req: HttpRequest) -> HttpResponse {
                if req.headers().get(header::AUTHORIZATION).is_none() {
                    HttpResponse::Ok().finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            App::new().service(web::resource("/").route(web::to(root)))
        });
        let srv2_port: u16 = srv2.addr().port();

        let srv1 = actix_test::start(move || {
            async fn root(req: HttpRequest) -> HttpResponse {
                let port = *req.app_data::<u16>().unwrap();
                if req.headers().get(header::AUTHORIZATION).is_some() {
                    HttpResponse::Found()
                        .append_header(("location", format!("http://localhost:{}/", port).as_str()))
                        .finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            async fn test1(req: HttpRequest) -> HttpResponse {
                if req.headers().get(header::AUTHORIZATION).is_some() {
                    HttpResponse::Found()
                        .append_header(("location", "/test2"))
                        .finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            async fn test2(req: HttpRequest) -> HttpResponse {
                if req.headers().get(header::AUTHORIZATION).is_some() {
                    HttpResponse::Ok().finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            App::new()
                .app_data(srv2_port)
                .service(web::resource("/").route(web::to(root)))
                .service(web::resource("/test1").route(web::to(test1)))
                .service(web::resource("/test2").route(web::to(test2)))
        });

        // send a request to different origins, http://srv1/ then http://srv2/. So it should remove the header
        let client = ClientBuilder::new()
            .add_default_header((header::AUTHORIZATION, "auth_key_value"))
            .finish();
        let res = client.get(srv1.url("/")).send().await.unwrap();
        assert_eq!(res.status().as_u16(), 200);

        // send a request to same origin, http://srv1/test1 then http://srv1/test2. So it should NOT remove any header
        let res = client.get(srv1.url("/test1")).send().await.unwrap();
        assert_eq!(res.status().as_u16(), 200);
    }

    /// A scheme-relative `Location` (`//host:port/test`) inherits the previous scheme.
    #[actix_rt::test]
    async fn test_double_slash_redirect() {
        let client = ClientBuilder::new()
            .disable_redirects()
            .wrap(Redirect::new().max_redirect_times(10))
            .finish();

        let srv = actix_test::start(|| {
            App::new()
                .service(web::resource("/test").route(web::to(|| async {
                    Ok::<_, Error>(HttpResponse::BadRequest())
                })))
                .service(
                    web::resource("/").route(web::to(|req: HttpRequest| async move {
                        Ok::<_, Error>(
                            HttpResponse::Found()
                                .append_header((
                                    "location",
                                    format!(
                                        "//localhost:{}/test",
                                        req.app_config().local_addr().port()
                                    )
                                    .as_str(),
                                ))
                                .finish(),
                        )
                    })),
                )
        });

        let res = client.get(srv.url("/")).send().await.unwrap();

        assert_eq!(res.status().as_u16(), 400);
    }

    /// Credentials are kept within one origin and stripped when scheme, host or port changes.
    ///
    /// The function under test is synchronous, so no async runtime is needed.
    #[test]
    fn test_remove_sensitive_headers() {
        fn gen_headers() -> header::HeaderMap {
            let mut headers = header::HeaderMap::new();
            headers.insert(header::USER_AGENT, HeaderValue::from_str("value").unwrap());
            headers.insert(
                header::AUTHORIZATION,
                HeaderValue::from_str("value").unwrap(),
            );
            headers.insert(
                header::PROXY_AUTHORIZATION,
                HeaderValue::from_str("value").unwrap(),
            );
            headers.insert(header::COOKIE, HeaderValue::from_str("value").unwrap());
            headers
        }

        // after a cross-origin redirect only the non-sensitive header may remain
        fn assert_only_user_agent_left(headers: &header::HeaderMap) {
            assert_eq!(headers.len(), 1);
            assert!(headers.contains_key(header::USER_AGENT));
        }

        // Same origin
        let prev_uri = Uri::from_str("https://host/path1").unwrap();
        let next_uri = Uri::from_str("https://host/path2").unwrap();
        let mut headers = gen_headers();
        remove_sensitive_headers(&mut headers, &prev_uri, &next_uri);
        assert_eq!(headers.len(), 4);

        // different schema
        let prev_uri = Uri::from_str("http://host/").unwrap();
        let next_uri = Uri::from_str("https://host/").unwrap();
        let mut headers = gen_headers();
        remove_sensitive_headers(&mut headers, &prev_uri, &next_uri);
        assert_only_user_agent_left(&headers);

        // different host
        let prev_uri = Uri::from_str("https://host1/").unwrap();
        let next_uri = Uri::from_str("https://host2/").unwrap();
        let mut headers = gen_headers();
        remove_sensitive_headers(&mut headers, &prev_uri, &next_uri);
        assert_only_user_agent_left(&headers);

        // different port
        let prev_uri = Uri::from_str("https://host:12/").unwrap();
        let next_uri = Uri::from_str("https://host:23/").unwrap();
        let mut headers = gen_headers();
        remove_sensitive_headers(&mut headers, &prev_uri, &next_uri);
        assert_only_user_agent_left(&headers);

        // different everything!
        let prev_uri = Uri::from_str("http://host1:12/path1").unwrap();
        let next_uri = Uri::from_str("https://host2:23/path2").unwrap();
        let mut headers = gen_headers();
        remove_sensitive_headers(&mut headers, &prev_uri, &next_uri);
        assert_only_user_agent_left(&headers);
    }
}