1use std::{
2 future::Future,
3 net::SocketAddr,
4 pin::Pin,
5 rc::Rc,
6 task::{Context, Poll},
7};
8
9use actix_http::{header, Method, RequestHead, RequestHeadType, StatusCode, Uri};
10use actix_service::Service;
11use bytes::Bytes;
12use futures_core::ready;
13
14use super::Transform;
15use crate::{
16 any_body::AnyBody,
17 client::{InvalidUrl, SendRequestError},
18 connect::{ConnectRequest, ConnectResponse},
19 ClientResponse,
20};
21
/// Client middleware that follows HTTP redirect responses.
///
/// `301`, `302`, `303`, `307` and `308` responses carrying a `Location` header are followed,
/// up to a configurable number of hops (see `Redirect::max_redirect_times`).
#[derive(Debug, Clone)]
pub struct Redirect {
    // maximum number of redirect hops to follow before returning the response as-is
    max_redirect_times: u8,
}
25
26impl Default for Redirect {
27 fn default() -> Self {
28 Self::new()
29 }
30}
31
32impl Redirect {
33 pub fn new() -> Self {
34 Self {
35 max_redirect_times: 10,
36 }
37 }
38
39 pub fn max_redirect_times(mut self, times: u8) -> Self {
40 self.max_redirect_times = times;
41 self
42 }
43}
44
45impl<S> Transform<S, ConnectRequest> for Redirect
46where
47 S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static,
48{
49 type Transform = RedirectService<S>;
50
51 fn new_transform(self, service: S) -> Self::Transform {
52 RedirectService {
53 max_redirect_times: self.max_redirect_times,
54 connector: Rc::new(service),
55 }
56 }
57}
58
/// Service produced by the `Redirect` middleware.
///
/// The wrapped connector sits behind an `Rc` so each request future can keep its own handle
/// to it while re-issuing requests for redirect hops.
pub struct RedirectService<S> {
    // the wrapped connector service
    connector: Rc<S>,
    // redirect budget copied into every request future
    max_redirect_times: u8,
}
63
64impl<S> Service<ConnectRequest> for RedirectService<S>
65where
66 S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static,
67{
68 type Response = S::Response;
69 type Error = S::Error;
70 type Future = RedirectServiceFuture<S>;
71
72 actix_service::forward_ready!(connector);
73
74 fn call(&self, req: ConnectRequest) -> Self::Future {
75 match req {
76 ConnectRequest::Tunnel(head, addr) => {
77 let fut = self.connector.call(ConnectRequest::Tunnel(head, addr));
78 RedirectServiceFuture::Tunnel { fut }
79 }
80 ConnectRequest::Client(head, body, addr) => {
81 let connector = Rc::clone(&self.connector);
82 let max_redirect_times = self.max_redirect_times;
83
84 let (uri, method, headers) = match head {
86 RequestHeadType::Owned(ref head) => {
87 (head.uri.clone(), head.method.clone(), head.headers.clone())
88 }
89 RequestHeadType::Rc(ref head, ..) => {
90 (head.uri.clone(), head.method.clone(), head.headers.clone())
91 }
92 };
93
94 let body_opt = match body {
95 AnyBody::Bytes { ref body } => Some(body.clone()),
96 _ => None,
97 };
98
99 let fut = connector.call(ConnectRequest::Client(head, body, addr));
100
101 RedirectServiceFuture::Client {
102 fut,
103 max_redirect_times,
104 uri: Some(uri),
105 method: Some(method),
106 headers: Some(headers),
107 body: body_opt,
108 addr,
109 connector: Some(connector),
110 }
111 }
112 }
113 }
114}
115
// Future returned by `RedirectService::call`.
//
// `Tunnel` just forwards the inner connector future. `Client` also keeps what is needed to
// re-issue the request when the response is a redirect. Its `Option` fields are taken (left
// `None`) while the next hop's future is built, and are then stored again in a fresh `Client`.
pin_project_lite::pin_project! {
    #[project = RedirectServiceProj]
    pub enum RedirectServiceFuture<S>
    where
        S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError>,
        S: 'static
    {
        // tunnel requests are passed straight through; redirects are not followed
        Tunnel { #[pin] fut: S::Future },
        Client {
            // in-flight response future for the current hop
            #[pin]
            fut: S::Future,
            // remaining redirect budget; at 0 a redirect response is returned as-is
            max_redirect_times: u8,
            // URI of the current hop, used to resolve the next `Location` and to decide
            // whether credentials must be stripped
            uri: Option<Uri>,
            method: Option<Method>,
            headers: Option<header::HeaderMap>,
            // copy of the body if it was `AnyBody::Bytes` (replayable for 307/308), else `None`
            body: Option<Bytes>,
            addr: Option<SocketAddr>,
            connector: Option<Rc<S>>,
        }
    }
}
137
138impl<S> Future for RedirectServiceFuture<S>
139where
140 S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static,
141{
142 type Output = Result<ConnectResponse, SendRequestError>;
143
144 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
145 match self.as_mut().project() {
146 RedirectServiceProj::Tunnel { fut } => fut.poll(cx),
147 RedirectServiceProj::Client {
148 fut,
149 max_redirect_times,
150 uri,
151 method,
152 headers,
153 body,
154 addr,
155 connector,
156 } => match ready!(fut.poll(cx))? {
157 ConnectResponse::Client(res) => match res.head().status {
158 StatusCode::MOVED_PERMANENTLY
159 | StatusCode::FOUND
160 | StatusCode::SEE_OTHER
161 | StatusCode::TEMPORARY_REDIRECT
162 | StatusCode::PERMANENT_REDIRECT
163 if *max_redirect_times > 0
164 && res.headers().contains_key(header::LOCATION) =>
165 {
166 let reuse_body = res.head().status == StatusCode::TEMPORARY_REDIRECT
167 || res.head().status == StatusCode::PERMANENT_REDIRECT;
168
169 let prev_uri = uri.take().unwrap();
170
171 let next_uri = build_next_uri(&res, &prev_uri)?;
173
174 let addr = addr.take();
176 let connector = connector.take();
177
178 let method = if reuse_body {
180 method.take().unwrap()
181 } else {
182 let method = method.take().unwrap();
183 match method {
184 Method::GET | Method::HEAD => method,
185 _ => Method::GET,
186 }
187 };
188
189 let mut body = body.take();
190 let body_new = if reuse_body {
191 match body {
193 Some(ref bytes) => AnyBody::Bytes {
194 body: bytes.clone(),
195 },
196
197 _ => AnyBody::empty(),
199 }
200 } else {
201 body = None;
202 AnyBody::None
204 };
205
206 let mut headers = headers.take().unwrap();
207
208 remove_sensitive_headers(&mut headers, &prev_uri, &next_uri);
209
210 let mut head = RequestHead::default();
212 head.uri = next_uri.clone();
213 head.method = method.clone();
214 head.headers = headers.clone();
215
216 let head = RequestHeadType::Owned(head);
217
218 let mut max_redirect_times = *max_redirect_times;
219 max_redirect_times -= 1;
220
221 let fut = connector
222 .as_ref()
223 .unwrap()
224 .call(ConnectRequest::Client(head, body_new, addr));
225
226 self.set(RedirectServiceFuture::Client {
227 fut,
228 max_redirect_times,
229 uri: Some(next_uri),
230 method: Some(method),
231 headers: Some(headers),
232 body,
233 addr,
234 connector,
235 });
236
237 self.poll(cx)
238 }
239 _ => Poll::Ready(Ok(ConnectResponse::Client(res))),
240 },
241 _ => unreachable!("ConnectRequest::Tunnel is not handled by Redirect"),
242 },
243 }
244 }
245}
246
247fn build_next_uri(res: &ClientResponse, prev_uri: &Uri) -> Result<Uri, SendRequestError> {
248 let location = res.headers().get(header::LOCATION).unwrap();
250
251 let uri = Uri::try_from(location.as_bytes()).unwrap_or_else(|_| Uri::default());
253
254 let uri = if uri.scheme().is_none() || uri.authority().is_none() {
255 let builder = Uri::builder()
256 .scheme(prev_uri.scheme().cloned().unwrap())
257 .authority(prev_uri.authority().cloned().unwrap());
258
259 if location.as_bytes().starts_with(b"//") {
261 let scheme = prev_uri.scheme_str().unwrap();
262 let mut full_url: Vec<u8> = scheme.as_bytes().to_vec();
263 full_url.push(b':');
264 full_url.extend(location.as_bytes());
265
266 return Uri::try_from(full_url)
267 .map_err(|_| SendRequestError::Url(InvalidUrl::MissingScheme));
268 }
269 let path = if location.as_bytes().starts_with(b"/") {
272 location.as_bytes().to_owned()
273 } else {
274 [b"/", location.as_bytes()].concat()
275 };
276
277 builder
278 .path_and_query(path)
279 .build()
280 .map_err(|err| SendRequestError::Url(InvalidUrl::HttpError(err)))?
281 } else {
282 uri
283 };
284
285 Ok(uri)
286}
287
288fn remove_sensitive_headers(headers: &mut header::HeaderMap, prev_uri: &Uri, next_uri: &Uri) {
289 if next_uri.host() != prev_uri.host()
290 || next_uri.port() != prev_uri.port()
291 || next_uri.scheme() != prev_uri.scheme()
292 {
293 headers.remove(header::COOKIE);
294 headers.remove(header::AUTHORIZATION);
295 headers.remove(header::PROXY_AUTHORIZATION);
296 }
297}
298
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use actix_web::{web, App, Error, HttpRequest, HttpResponse};

    use super::*;
    use crate::{http::header::HeaderValue, ClientBuilder};

    // A 302 from `/` to the path `/test` is followed; the caller sees the final 400.
    // (`disable_redirects()` presumably turns off the builder's own redirect handling so only
    // the explicitly wrapped `Redirect` runs — confirm.)
    #[actix_rt::test]
    async fn basic_redirect() {
        let client = ClientBuilder::new()
            .disable_redirects()
            .wrap(Redirect::new().max_redirect_times(10))
            .finish();

        let srv = actix_test::start(|| {
            App::new()
                .service(web::resource("/test").route(web::to(|| async {
                    Ok::<_, Error>(HttpResponse::BadRequest())
                })))
                .service(web::resource("/").route(web::to(|| async {
                    Ok::<_, Error>(
                        HttpResponse::Found()
                            .append_header(("location", "/test"))
                            .finish(),
                    )
                })))
        });

        let res = client.get(srv.url("/")).send().await.unwrap();

        assert_eq!(res.status().as_u16(), 400);
    }

    // A `Location` without a leading slash ("abc/") requested from `/` must land on `/abc/`.
    #[actix_rt::test]
    async fn redirect_relative_without_leading_slash() {
        let client = ClientBuilder::new().finish();

        let srv = actix_test::start(|| {
            App::new()
                .service(web::resource("/").route(web::to(|| async {
                    HttpResponse::Found()
                        .insert_header(("location", "abc/"))
                        .finish()
                })))
                .service(
                    web::resource("/abc/")
                        .route(web::to(|| async { HttpResponse::Accepted().finish() })),
                )
        });

        let res = client.get(srv.url("/")).send().await.unwrap();
        assert_eq!(res.status(), StatusCode::ACCEPTED);
    }

    // A redirect status without a `Location` header is returned to the caller unchanged.
    #[actix_rt::test]
    async fn redirect_without_location() {
        let client = ClientBuilder::new()
            .disable_redirects()
            .wrap(Redirect::new().max_redirect_times(10))
            .finish();

        let srv = actix_test::start(|| {
            App::new().service(web::resource("/").route(web::to(|| async {
                Ok::<_, Error>(HttpResponse::Found().finish())
            })))
        });

        let res = client.get(srv.url("/")).send().await.unwrap();
        assert_eq!(res.status(), StatusCode::FOUND);
    }

    // With a budget of one hop, `/` -> `/test` is followed but `/test` -> `/test2` is not, so
    // the caller receives the second 302 pointing at `/test2`.
    #[actix_rt::test]
    async fn test_redirect_limit() {
        let client = ClientBuilder::new()
            .disable_redirects()
            .wrap(Redirect::new().max_redirect_times(1))
            .connector(crate::Connector::new())
            .finish();

        let srv = actix_test::start(|| {
            App::new()
                .service(web::resource("/").route(web::to(|| async {
                    Ok::<_, Error>(
                        HttpResponse::Found()
                            .insert_header(("location", "/test"))
                            .finish(),
                    )
                })))
                .service(web::resource("/test").route(web::to(|| async {
                    Ok::<_, Error>(
                        HttpResponse::Found()
                            .insert_header(("location", "/test2"))
                            .finish(),
                    )
                })))
                .service(web::resource("/test2").route(web::to(|| async {
                    Ok::<_, Error>(HttpResponse::BadRequest())
                })))
        });

        let res = client.get(srv.url("/")).send().await.unwrap();
        assert_eq!(res.status(), StatusCode::FOUND);
        assert_eq!(
            res.headers()
                .get(header::LOCATION)
                .unwrap()
                .to_str()
                .unwrap(),
            "/test2"
        );
    }

    // A 307 must replay the request with its original method (POST) and its non-empty body.
    #[actix_rt::test]
    async fn test_redirect_status_kind_307_308() {
        let srv = actix_test::start(|| {
            async fn root() -> HttpResponse {
                HttpResponse::TemporaryRedirect()
                    .append_header(("location", "/test"))
                    .finish()
            }

            async fn test(req: HttpRequest, body: Bytes) -> HttpResponse {
                if req.method() == Method::POST && !body.is_empty() {
                    HttpResponse::Ok().finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            App::new()
                .service(web::resource("/").route(web::to(root)))
                .service(web::resource("/test").route(web::to(test)))
        });

        let res = srv.post("/").send_body("Hello").await.unwrap();
        assert_eq!(res.status().as_u16(), 200);
    }

    // A 302 turns a POST into a GET (or HEAD) without a body, whether or not the original
    // request carried one.
    #[actix_rt::test]
    async fn test_redirect_status_kind_301_302_303() {
        let srv = actix_test::start(|| {
            async fn root() -> HttpResponse {
                HttpResponse::Found()
                    .append_header(("location", "/test"))
                    .finish()
            }

            async fn test(req: HttpRequest, body: Bytes) -> HttpResponse {
                if (req.method() == Method::GET || req.method() == Method::HEAD) && body.is_empty()
                {
                    HttpResponse::Ok().finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            App::new()
                .service(web::resource("/").route(web::to(root)))
                .service(web::resource("/test").route(web::to(test)))
        });

        let res = srv.post("/").send_body("Hello").await.unwrap();
        assert_eq!(res.status().as_u16(), 200);

        let res = srv.post("/").send().await.unwrap();
        assert_eq!(res.status().as_u16(), 200);
    }

    // Non-sensitive custom headers survive a same-origin redirect, whether they come from the
    // client's default headers or are set on the individual request.
    #[actix_rt::test]
    async fn test_redirect_headers() {
        let srv = actix_test::start(|| {
            async fn root(req: HttpRequest) -> HttpResponse {
                if req
                    .headers()
                    .get("custom")
                    .unwrap_or(&HeaderValue::from_str("").unwrap())
                    == "value"
                {
                    HttpResponse::Found()
                        .append_header(("location", "/test"))
                        .finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            async fn test(req: HttpRequest) -> HttpResponse {
                if req
                    .headers()
                    .get("custom")
                    .unwrap_or(&HeaderValue::from_str("").unwrap())
                    == "value"
                {
                    HttpResponse::Ok().finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            App::new()
                .service(web::resource("/").route(web::to(root)))
                .service(web::resource("/test").route(web::to(test)))
        });

        // with redirects disabled the 302 itself is returned
        let client = ClientBuilder::new()
            .add_default_header(("custom", "value"))
            .disable_redirects()
            .finish();
        let res = client.get(srv.url("/")).send().await.unwrap();
        assert_eq!(res.status().as_u16(), 302);

        let client = ClientBuilder::new()
            .add_default_header(("custom", "value"))
            .finish();
        let res = client.get(srv.url("/")).send().await.unwrap();
        assert_eq!(res.status().as_u16(), 200);

        let client = ClientBuilder::new().finish();
        let res = client
            .get(srv.url("/"))
            .insert_header(("custom", "value"))
            .send()
            .await
            .unwrap();
        assert_eq!(res.status().as_u16(), 200);
    }

    // `Authorization` is dropped when the redirect points at a different origin (srv2, on a
    // different port) and kept for a same-origin hop (`/test1` -> `/test2` on srv1).
    #[actix_rt::test]
    async fn test_redirect_cross_origin_headers() {
        let srv2 = actix_test::start(|| {
            async fn root(req: HttpRequest) -> HttpResponse {
                if req.headers().get(header::AUTHORIZATION).is_none() {
                    HttpResponse::Ok().finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            App::new().service(web::resource("/").route(web::to(root)))
        });
        let srv2_port: u16 = srv2.addr().port();

        let srv1 = actix_test::start(move || {
            async fn root(req: HttpRequest) -> HttpResponse {
                let port = *req.app_data::<u16>().unwrap();
                if req.headers().get(header::AUTHORIZATION).is_some() {
                    HttpResponse::Found()
                        .append_header(("location", format!("http://localhost:{}/", port).as_str()))
                        .finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            async fn test1(req: HttpRequest) -> HttpResponse {
                if req.headers().get(header::AUTHORIZATION).is_some() {
                    HttpResponse::Found()
                        .append_header(("location", "/test2"))
                        .finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            async fn test2(req: HttpRequest) -> HttpResponse {
                if req.headers().get(header::AUTHORIZATION).is_some() {
                    HttpResponse::Ok().finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            App::new()
                .app_data(srv2_port)
                .service(web::resource("/").route(web::to(root)))
                .service(web::resource("/test1").route(web::to(test1)))
                .service(web::resource("/test2").route(web::to(test2)))
        });

        let client = ClientBuilder::new()
            .add_default_header((header::AUTHORIZATION, "auth_key_value"))
            .finish();
        // cross-origin hop: srv2 must not see the Authorization header
        let res = client.get(srv1.url("/")).send().await.unwrap();
        assert_eq!(res.status().as_u16(), 200);

        // same-origin hop: srv1's /test2 must still see it
        let res = client.get(srv1.url("/test1")).send().await.unwrap();
        assert_eq!(res.status().as_u16(), 200);
    }

    // A scheme-relative `Location` ("//localhost:PORT/test") inherits the previous scheme.
    #[actix_rt::test]
    async fn test_double_slash_redirect() {
        let client = ClientBuilder::new()
            .disable_redirects()
            .wrap(Redirect::new().max_redirect_times(10))
            .finish();

        let srv = actix_test::start(|| {
            App::new()
                .service(web::resource("/test").route(web::to(|| async {
                    Ok::<_, Error>(HttpResponse::BadRequest())
                })))
                .service(
                    web::resource("/").route(web::to(|req: HttpRequest| async move {
                        Ok::<_, Error>(
                            HttpResponse::Found()
                                .append_header((
                                    "location",
                                    format!(
                                        "//localhost:{}/test",
                                        req.app_config().local_addr().port()
                                    )
                                    .as_str(),
                                ))
                                .finish(),
                        )
                    })),
                )
        });

        let res = client.get(srv.url("/")).send().await.unwrap();

        assert_eq!(res.status().as_u16(), 400);
    }

    // Unit test for `remove_sensitive_headers`: all four headers survive a same-origin change.
    // Only `User-Agent` remains when scheme, host or port differ.
    #[actix_rt::test]
    async fn test_remove_sensitive_headers() {
        fn gen_headers() -> header::HeaderMap {
            let mut headers = header::HeaderMap::new();
            headers.insert(header::USER_AGENT, HeaderValue::from_str("value").unwrap());
            headers.insert(
                header::AUTHORIZATION,
                HeaderValue::from_str("value").unwrap(),
            );
            headers.insert(
                header::PROXY_AUTHORIZATION,
                HeaderValue::from_str("value").unwrap(),
            );
            headers.insert(header::COOKIE, HeaderValue::from_str("value").unwrap());
            headers
        }

        // same origin, different path
        let prev_uri = Uri::from_str("https://host/path1").unwrap();
        let next_uri = Uri::from_str("https://host/path2").unwrap();
        let mut headers = gen_headers();
        remove_sensitive_headers(&mut headers, &prev_uri, &next_uri);
        assert_eq!(headers.len(), 4);

        // scheme change
        let prev_uri = Uri::from_str("http://host/").unwrap();
        let next_uri = Uri::from_str("https://host/").unwrap();
        let mut headers = gen_headers();
        remove_sensitive_headers(&mut headers, &prev_uri, &next_uri);
        assert_eq!(headers.len(), 1);

        // host change
        let prev_uri = Uri::from_str("https://host1/").unwrap();
        let next_uri = Uri::from_str("https://host2/").unwrap();
        let mut headers = gen_headers();
        remove_sensitive_headers(&mut headers, &prev_uri, &next_uri);
        assert_eq!(headers.len(), 1);

        // port change
        let prev_uri = Uri::from_str("https://host:12/").unwrap();
        let next_uri = Uri::from_str("https://host:23/").unwrap();
        let mut headers = gen_headers();
        remove_sensitive_headers(&mut headers, &prev_uri, &next_uri);
        assert_eq!(headers.len(), 1);

        // everything changes
        let prev_uri = Uri::from_str("http://host1:12/path1").unwrap();
        let next_uri = Uri::from_str("https://host2:23/path2").unwrap();
        let mut headers = gen_headers();
        remove_sensitive_headers(&mut headers, &prev_uri, &next_uri);
        assert_eq!(headers.len(), 1);
    }
}