1use crate::aws::{AwsCredentialProvider, STORE, STRICT_ENCODE_SET, STRICT_PATH_ENCODE_SET};
19use crate::client::builder::HttpRequestBuilder;
20use crate::client::retry::RetryExt;
21use crate::client::token::{TemporaryToken, TokenCache};
22use crate::client::{HttpClient, HttpError, HttpRequest, TokenProvider};
23use crate::util::{hex_digest, hex_encode, hmac_sha256};
24use crate::{CredentialProvider, Result, RetryConfig};
25use async_trait::async_trait;
26use bytes::Buf;
27use chrono::{DateTime, Utc};
28use http::header::{AUTHORIZATION, HeaderMap, HeaderName, HeaderValue};
29use http::{Method, StatusCode};
30use percent_encoding::utf8_percent_encode;
31use serde::Deserialize;
32use std::collections::BTreeMap;
33use std::sync::Arc;
34use std::time::{Duration, Instant};
35use tracing::warn;
36use url::Url;
37
/// Failures that can occur while obtaining S3 session credentials via
/// `CreateSession` (see `SessionProvider::fetch_token`).
#[derive(Debug, thiserror::Error)]
// Every variant intentionally shares the `CreateSession` prefix, which
// clippy's `enum_variant_names` lint would otherwise flag.
#[allow(clippy::enum_variant_names)]
enum Error {
    /// The signed `CreateSession` request failed (after retries).
    #[error("Error performing CreateSession request: {source}")]
    CreateSessionRequest {
        source: crate::client::retry::RetryError,
    },

    /// The response body could not be read.
    #[error("Error getting CreateSession response: {source}")]
    CreateSessionResponse { source: HttpError },

    /// The response body could not be deserialized as `CreateSessionOutput` XML.
    #[error("Invalid CreateSessionOutput response: {source}")]
    CreateSessionOutput { source: quick_xml::DeError },
}
52
53impl From<Error> for crate::Error {
54 fn from(value: Error) -> Self {
55 Self::Generic {
56 store: STORE,
57 source: Box::new(value),
58 }
59 }
60}
61
/// Boxed, thread-safe error type returned by the credential-fetching helpers
/// in this module; callers wrap it into `crate::Error::Generic`.
type StdError = Box<dyn std::error::Error + Send + Sync>;

/// Hex-encoded SHA-256 digest of an empty byte string, used as the payload
/// hash when a signed request has an empty body.
static EMPTY_SHA256_HASH: &str = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
/// Payload-hash placeholder used when payload signing is disabled.
static UNSIGNED_PAYLOAD: &str = "UNSIGNED-PAYLOAD";
/// Payload-hash placeholder used when the body is non-empty but not available
/// as contiguous bytes (e.g. a streaming body).
static STREAMING_PAYLOAD: &str = "STREAMING-AWS4-HMAC-SHA256-PAYLOAD";
68
/// A set of AWS credentials used to sign requests.
///
/// `Debug` is implemented manually so that the secret and token are redacted.
#[derive(Eq, PartialEq)]
pub struct AwsCredential {
    /// Access key ID; appears in clear text in the `Credential=` part of the
    /// `Authorization` header and presigned URLs.
    pub key_id: String,
    /// Secret access key; seeds the SigV4 signing-key derivation.
    pub secret_key: String,
    /// Optional session token; when present it is sent as a header (or query
    /// parameter for presigned URLs) alongside the signature.
    pub token: Option<String>,
}
79
80impl std::fmt::Debug for AwsCredential {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 f.debug_struct("AwsCredential")
83 .field("key_id", &self.key_id)
84 .field("secret_key", &"******")
85 .field("token", &self.token.as_ref().map(|_| "******"))
86 .finish()
87 }
88}
89
90impl AwsCredential {
91 fn sign(&self, to_sign: &str, date: DateTime<Utc>, region: &str, service: &str) -> String {
95 let date_string = date.format("%Y%m%d").to_string();
96 let date_hmac = hmac_sha256(format!("AWS4{}", self.secret_key), date_string);
97 let region_hmac = hmac_sha256(date_hmac, region);
98 let service_hmac = hmac_sha256(region_hmac, service);
99 let signing_hmac = hmac_sha256(service_hmac, b"aws4_request");
100 hex_encode(hmac_sha256(signing_hmac, to_sign).as_ref())
101 }
102}
103
/// Signs HTTP requests (via headers) or URLs (via query parameters) with the
/// `AWS4-HMAC-SHA256` algorithm, using a borrowed [`AwsCredential`].
#[derive(Debug)]
pub struct AwsAuthorizer<'a> {
    /// Fixed signing time; `None` means `Utc::now()` at signing time.
    date: Option<DateTime<Utc>>,
    /// Credential whose key ID and secret produce the signature.
    credential: &'a AwsCredential,
    /// Service name used in the credential scope (e.g. `"s3"`).
    service: &'a str,
    /// Region used in the credential scope.
    region: &'a str,
    /// Overrides the header name used to send the session token; when `None`
    /// `x-amz-security-token` is used.
    token_header: Option<HeaderName>,
    /// When `true`, the request body hash is signed; otherwise
    /// `UNSIGNED-PAYLOAD` is used.
    sign_payload: bool,
    /// When `true`, the request-payer header / query parameter is added.
    request_payer: bool,
}
117
/// Header carrying the request timestamp (`YYYYMMDDTHHMMSSZ`).
static DATE_HEADER: HeaderName = HeaderName::from_static("x-amz-date");
/// Header carrying the hex payload hash or one of the payload placeholders.
static HASH_HEADER: HeaderName = HeaderName::from_static("x-amz-content-sha256");
/// Default header used to send the session token.
static TOKEN_HEADER: HeaderName = HeaderName::from_static("x-amz-security-token");
/// Header added when request-payer mode is enabled.
static REQUEST_PAYER_HEADER: HeaderName = HeaderName::from_static("x-amz-request-payer");
/// Value sent in `REQUEST_PAYER_HEADER`.
static REQUEST_PAYER_HEADER_VALUE: HeaderValue = HeaderValue::from_static("requester");
/// SigV4 algorithm identifier used in the string to sign and `Authorization`.
const ALGORITHM: &str = "AWS4-HMAC-SHA256";
124
impl<'a> AwsAuthorizer<'a> {
    /// Creates an authorizer for `service` in `region`.
    ///
    /// Defaults: signs the payload, uses the current time, sends the session
    /// token in `x-amz-security-token`, and does not set request-payer.
    pub fn new(credential: &'a AwsCredential, service: &'a str, region: &'a str) -> Self {
        Self {
            credential,
            service,
            region,
            date: None,
            sign_payload: true,
            token_header: None,
            request_payer: false,
        }
    }

    /// Controls whether the request body is hashed into the signature. When
    /// `false`, `UNSIGNED-PAYLOAD` is used as the payload hash.
    pub fn with_sign_payload(mut self, signed: bool) -> Self {
        self.sign_payload = signed;
        self
    }

    /// Sends the session token in `header` instead of `x-amz-security-token`.
    pub(crate) fn with_token_header(mut self, header: HeaderName) -> Self {
        self.token_header = Some(header);
        self
    }

    /// When `true`, adds `x-amz-request-payer: requester` to signed requests
    /// (and the equivalent query parameter to presigned URLs).
    pub fn with_request_payer(mut self, request_payer: bool) -> Self {
        self.request_payer = request_payer;
        self
    }

    /// Signs `request` in place using header-based SigV4 authentication.
    ///
    /// Inserts, in order: the session token (if any), `host`, `x-amz-date`,
    /// `x-amz-content-sha256`, optionally `x-amz-request-payer`, and finally
    /// `Authorization`. Headers must be present before `canonicalize_headers`
    /// runs so that they are included in `SignedHeaders`.
    ///
    /// `pre_calculated_digest` is a raw digest that is hex-encoded and used as
    /// the payload hash when payload signing is enabled; otherwise the body is
    /// hashed here (or a placeholder is used for empty/streaming bodies).
    ///
    /// # Panics
    /// If the request URI cannot be re-parsed as a `Url`, or if any computed
    /// value (token, host, date, digest, authorization) is not a valid header
    /// value.
    pub fn authorize(&self, request: &mut HttpRequest, pre_calculated_digest: Option<&[u8]>) {
        let url = Url::parse(&request.uri().to_string()).unwrap();

        if let Some(ref token) = self.credential.token {
            let token_val = HeaderValue::from_str(token).unwrap();
            let header = self.token_header.as_ref().unwrap_or(&TOKEN_HEADER);
            request.headers_mut().insert(header, token_val);
        }

        // `host` includes the port when the URL has one.
        let host = &url[url::Position::BeforeHost..url::Position::AfterPort];
        let host_val = HeaderValue::from_str(host).unwrap();
        request.headers_mut().insert("host", host_val);

        let date = self.date.unwrap_or_else(Utc::now);
        let date_str = date.format("%Y%m%dT%H%M%SZ").to_string();
        let date_val = HeaderValue::from_str(&date_str).unwrap();
        request.headers_mut().insert(&DATE_HEADER, date_val);

        // Payload hash: unsigned placeholder, caller-supplied digest, the
        // empty-body hash, a freshly computed hash, or the streaming placeholder.
        let digest = match self.sign_payload {
            false => UNSIGNED_PAYLOAD.to_string(),
            true => match pre_calculated_digest {
                Some(digest) => hex_encode(digest),
                None => match request.body().is_empty() {
                    true => EMPTY_SHA256_HASH.to_string(),
                    false => match request.body().as_bytes() {
                        Some(bytes) => hex_digest(bytes),
                        None => STREAMING_PAYLOAD.to_string(),
                    },
                },
            },
        };

        let header_digest = HeaderValue::from_str(&digest).unwrap();
        request.headers_mut().insert(&HASH_HEADER, header_digest);

        if self.request_payer {
            request
                .headers_mut()
                .insert(&REQUEST_PAYER_HEADER, REQUEST_PAYER_HEADER_VALUE.clone());
        }

        let (signed_headers, canonical_headers) = canonicalize_headers(request.headers());

        let scope = self.scope(date);

        let string_to_sign = self.string_to_sign(
            date,
            &scope,
            request.method(),
            &url,
            &canonical_headers,
            &signed_headers,
            &digest,
        );

        let signature = self
            .credential
            .sign(&string_to_sign, date, self.region, self.service);

        let authorisation = format!(
            "{} Credential={}/{}, SignedHeaders={}, Signature={}",
            ALGORITHM, self.credential.key_id, scope, signed_headers, signature
        );

        let authorization_val = HeaderValue::from_str(&authorisation).unwrap();
        request
            .headers_mut()
            .insert(&AUTHORIZATION, authorization_val);
    }

    /// Turns `url` into a presigned URL for `method`, valid for `expires_in`.
    ///
    /// Appends the `X-Amz-*` query parameters (plus request-payer and session
    /// token when applicable), signs only the `host` header with an
    /// `UNSIGNED-PAYLOAD` hash, then appends `X-Amz-Signature`. The signature
    /// must be appended last because all earlier parameters are part of the
    /// canonical query that it covers.
    ///
    /// # Panics
    /// If the URL host is not a valid header value.
    pub(crate) fn sign(&self, method: Method, url: &mut Url, expires_in: Duration) {
        let date = self.date.unwrap_or_else(Utc::now);
        let scope = self.scope(date);

        url.query_pairs_mut()
            .append_pair("X-Amz-Algorithm", ALGORITHM)
            .append_pair(
                "X-Amz-Credential",
                &format!("{}/{}", self.credential.key_id, scope),
            )
            .append_pair("X-Amz-Date", &date.format("%Y%m%dT%H%M%SZ").to_string())
            .append_pair("X-Amz-Expires", &expires_in.as_secs().to_string())
            .append_pair("X-Amz-SignedHeaders", "host");

        if self.request_payer {
            url.query_pairs_mut()
                .append_pair("x-amz-request-payer", "requester");
        }

        if let Some(ref token) = self.credential.token {
            url.query_pairs_mut()
                .append_pair("X-Amz-Security-Token", token);
        }

        // Presigned URLs never sign the payload.
        let digest = UNSIGNED_PAYLOAD;

        let host = &url[url::Position::BeforeHost..url::Position::AfterPort].to_string();
        let mut headers = HeaderMap::new();
        let host_val = HeaderValue::from_str(host).unwrap();
        headers.insert("host", host_val);

        let (signed_headers, canonical_headers) = canonicalize_headers(&headers);

        let string_to_sign = self.string_to_sign(
            date,
            &scope,
            &method,
            url,
            &canonical_headers,
            &signed_headers,
            digest,
        );

        let signature = self
            .credential
            .sign(&string_to_sign, date, self.region, self.service);

        url.query_pairs_mut()
            .append_pair("X-Amz-Signature", &signature);
    }

    /// Builds the SigV4 string to sign: the algorithm, timestamp, `scope`, and
    /// the hex SHA-256 of the canonical request (method, URI, query, headers,
    /// signed-header list, payload hash, newline-separated).
    #[allow(clippy::too_many_arguments)]
    fn string_to_sign(
        &self,
        date: DateTime<Utc>,
        scope: &str,
        request_method: &Method,
        url: &Url,
        canonical_headers: &str,
        signed_headers: &str,
        digest: &str,
    ) -> String {
        // S3 uses the already-encoded path verbatim; other services get the
        // path percent-encoded a second time.
        let canonical_uri = match self.service {
            "s3" => url.path().to_string(),
            _ => utf8_percent_encode(url.path(), &STRICT_PATH_ENCODE_SET).to_string(),
        };

        let canonical_query = canonicalize_query(url);

        // `canonical_headers` already ends with '\n', so the blank line the
        // format requires between headers and the signed-header list appears.
        let canonical_request = format!(
            "{}\n{}\n{}\n{}\n{}\n{}",
            request_method.as_str(),
            canonical_uri,
            canonical_query,
            canonical_headers,
            signed_headers,
            digest
        );

        let hashed_canonical_request = hex_digest(canonical_request.as_bytes());

        format!(
            "{}\n{}\n{}\n{}",
            ALGORITHM,
            date.format("%Y%m%dT%H%M%SZ"),
            scope,
            hashed_canonical_request
        )
    }

    /// Returns the credential scope `YYYYMMDD/region/service/aws4_request`.
    fn scope(&self, date: DateTime<Utc>) -> String {
        format!(
            "{}/{}/{}/aws4_request",
            date.format("%Y%m%d"),
            self.region,
            self.service
        )
    }
}
357
/// Extension for attaching SigV4 authentication to an outgoing request.
pub(crate) trait CredentialExt {
    /// Signs the request with `authorizer` when one is supplied.
    ///
    /// `payload_sha256`, if given, is the pre-computed raw SHA-256 digest of the
    /// payload, forwarded to `AwsAuthorizer::authorize`.
    fn with_aws_sigv4(
        self,
        authorizer: Option<AwsAuthorizer<'_>>,
        payload_sha256: Option<&[u8]>,
    ) -> Self;
}
366
367impl CredentialExt for HttpRequestBuilder {
368 fn with_aws_sigv4(
369 self,
370 authorizer: Option<AwsAuthorizer<'_>>,
371 payload_sha256: Option<&[u8]>,
372 ) -> Self {
373 match authorizer {
374 Some(authorizer) => {
375 let (client, request) = self.into_parts();
376 let mut request = request.expect("request valid");
377 authorizer.authorize(&mut request, payload_sha256);
378
379 Self::from_parts(client, request)
380 }
381 None => self,
382 }
383 }
384}
385
386fn canonicalize_query(url: &Url) -> String {
390 use std::fmt::Write;
391
392 let capacity = match url.query() {
393 Some(q) if !q.is_empty() => q.len(),
394 _ => return String::new(),
395 };
396 let mut encoded = String::with_capacity(capacity + 1);
397
398 let mut headers = url.query_pairs().collect::<Vec<_>>();
399 headers.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
400
401 let mut first = true;
402 for (k, v) in headers {
403 if !first {
404 encoded.push('&');
405 }
406 first = false;
407 let _ = write!(
408 encoded,
409 "{}={}",
410 utf8_percent_encode(k.as_ref(), &STRICT_ENCODE_SET),
411 utf8_percent_encode(v.as_ref(), &STRICT_ENCODE_SET)
412 );
413 }
414 encoded
415}
416
/// Appends `input` to `headers` with leading/trailing whitespace removed and
/// every internal run of whitespace collapsed to a single space.
///
/// Appends nothing if `input` is empty or all whitespace.
fn append_normalized_whitespace_value(headers: &'_ mut String, input: &str) {
    for (position, word) in input.split_whitespace().enumerate() {
        if position > 0 {
            headers.push(' ');
        }
        headers.push_str(word);
    }
}
428
429fn canonicalize_headers(header_map: &HeaderMap) -> (String, String) {
433 let mut headers = BTreeMap::<&str, Vec<&str>>::new();
434 let mut value_count = 0;
435 let mut value_bytes = 0;
436 let mut key_bytes = 0;
437
438 for (key, value) in header_map {
439 let key = key.as_str();
440 if ["authorization", "content-length", "user-agent"].contains(&key) {
441 continue;
442 }
443
444 let value = std::str::from_utf8(value.as_bytes()).unwrap();
445 key_bytes += key.len();
446 value_bytes += value.len();
447 value_count += 1;
448 headers.entry(key).or_default().push(value);
449 }
450
451 let mut signed_headers = String::with_capacity(key_bytes + headers.len());
452 let mut canonical_headers =
453 String::with_capacity(key_bytes + value_bytes + headers.len() + value_count);
454
455 for (header_idx, (name, values)) in headers.into_iter().enumerate() {
456 if header_idx != 0 {
457 signed_headers.push(';');
458 }
459
460 signed_headers.push_str(name);
461 canonical_headers.push_str(name);
462 canonical_headers.push(':');
463 for (value_idx, value) in values.into_iter().enumerate() {
464 if value_idx != 0 {
465 canonical_headers.push(',');
466 }
467 append_normalized_whitespace_value(&mut canonical_headers, value.trim());
468 }
469 canonical_headers.push('\n');
470 }
471
472 (signed_headers, canonical_headers)
473}
474
/// Fetches credentials from the EC2 instance metadata service (see
/// `instance_creds`).
#[derive(Debug)]
pub(crate) struct InstanceCredentialProvider {
    /// If the IMDSv2 token request is rejected with 403, continue without a
    /// token (IMDSv1) instead of failing.
    pub imdsv1_fallback: bool,
    /// Base URL of the metadata service; paths like `latest/api/token` are
    /// appended to it.
    pub metadata_endpoint: String,
}
483
484#[async_trait]
485impl TokenProvider for InstanceCredentialProvider {
486 type Credential = AwsCredential;
487
488 async fn fetch_token(
489 &self,
490 client: &HttpClient,
491 retry: &RetryConfig,
492 ) -> Result<TemporaryToken<Arc<AwsCredential>>> {
493 instance_creds(client, retry, &self.metadata_endpoint, self.imdsv1_fallback)
494 .await
495 .map_err(|source| crate::Error::Generic {
496 store: STORE,
497 source,
498 })
499 }
500}
501
/// Obtains credentials by exchanging a web identity token for a role session
/// via STS `AssumeRoleWithWebIdentity` (see `web_identity`).
#[derive(Debug)]
pub(crate) struct WebIdentityProvider {
    /// Path of the file containing the web identity token; it is re-read on
    /// every fetch.
    pub token_path: String,
    /// ARN of the role to assume.
    pub role_arn: String,
    /// Value sent as `RoleSessionName`.
    pub session_name: String,
    /// URL the `AssumeRoleWithWebIdentity` POST is sent to.
    pub endpoint: String,
}
512
513#[async_trait]
514impl TokenProvider for WebIdentityProvider {
515 type Credential = AwsCredential;
516
517 async fn fetch_token(
518 &self,
519 client: &HttpClient,
520 retry: &RetryConfig,
521 ) -> Result<TemporaryToken<Arc<AwsCredential>>> {
522 web_identity(
523 client,
524 retry,
525 &self.token_path,
526 &self.role_arn,
527 &self.session_name,
528 &self.endpoint,
529 )
530 .await
531 .map_err(|source| crate::Error::Generic {
532 store: STORE,
533 source,
534 })
535 }
536}
537
538#[derive(Debug, Deserialize)]
539#[serde(rename_all = "PascalCase")]
540struct InstanceCredentials {
541 access_key_id: String,
542 secret_access_key: String,
543 token: String,
544 expiration: DateTime<Utc>,
545}
546
547impl From<InstanceCredentials> for AwsCredential {
548 fn from(s: InstanceCredentials) -> Self {
549 Self {
550 key_id: s.access_key_id,
551 secret_key: s.secret_access_key,
552 token: Some(s.token),
553 }
554 }
555}
556
557async fn instance_creds(
559 client: &HttpClient,
560 retry_config: &RetryConfig,
561 endpoint: &str,
562 imdsv1_fallback: bool,
563) -> Result<TemporaryToken<Arc<AwsCredential>>, StdError> {
564 const CREDENTIALS_PATH: &str = "latest/meta-data/iam/security-credentials";
565 const AWS_EC2_METADATA_TOKEN_HEADER: &str = "X-aws-ec2-metadata-token";
566
567 let token_url = format!("{endpoint}/latest/api/token");
568
569 let token_result = client
570 .request(Method::PUT, token_url)
571 .header("X-aws-ec2-metadata-token-ttl-seconds", "600") .retryable(retry_config)
573 .idempotent(true)
574 .send()
575 .await;
576
577 let token = match token_result {
578 Ok(t) => Some(t.into_body().text().await?),
579 Err(e) if imdsv1_fallback && matches!(e.status(), Some(StatusCode::FORBIDDEN)) => {
580 warn!("received 403 from metadata endpoint, falling back to IMDSv1");
581 None
582 }
583 Err(e) => return Err(e.into()),
584 };
585
586 let role_url = format!("{endpoint}/{CREDENTIALS_PATH}/");
587 let mut role_request = client.request(Method::GET, role_url);
588
589 if let Some(token) = &token {
590 role_request = role_request.header(AWS_EC2_METADATA_TOKEN_HEADER, token);
591 }
592
593 let role = role_request
594 .send_retry(retry_config)
595 .await?
596 .into_body()
597 .text()
598 .await?;
599
600 let creds_url = format!("{endpoint}/{CREDENTIALS_PATH}/{role}");
601 let mut creds_request = client.request(Method::GET, creds_url);
602 if let Some(token) = &token {
603 creds_request = creds_request.header(AWS_EC2_METADATA_TOKEN_HEADER, token);
604 }
605
606 let creds: InstanceCredentials = creds_request
607 .send_retry(retry_config)
608 .await?
609 .into_body()
610 .json()
611 .await?;
612
613 let now = Utc::now();
614 let ttl = (creds.expiration - now).to_std().unwrap_or_default();
615 Ok(TemporaryToken {
616 token: Arc::new(creds.into()),
617 expiry: Some(Instant::now() + ttl),
618 })
619}
620
621#[derive(Debug, Deserialize)]
622#[serde(rename_all = "PascalCase")]
623struct AssumeRoleResponse {
624 assume_role_with_web_identity_result: AssumeRoleResult,
625}
626
627#[derive(Debug, Deserialize)]
628#[serde(rename_all = "PascalCase")]
629struct AssumeRoleResult {
630 credentials: SessionCredentials,
631}
632
633#[derive(Debug, Deserialize)]
634#[serde(rename_all = "PascalCase")]
635struct SessionCredentials {
636 session_token: String,
637 secret_access_key: String,
638 access_key_id: String,
639 expiration: DateTime<Utc>,
640}
641
642impl From<SessionCredentials> for AwsCredential {
643 fn from(s: SessionCredentials) -> Self {
644 Self {
645 key_id: s.access_key_id,
646 secret_key: s.secret_access_key,
647 token: Some(s.session_token),
648 }
649 }
650}
651
652async fn web_identity(
654 client: &HttpClient,
655 retry_config: &RetryConfig,
656 token_path: &str,
657 role_arn: &str,
658 session_name: &str,
659 endpoint: &str,
660) -> Result<TemporaryToken<Arc<AwsCredential>>, StdError> {
661 let token = std::fs::read_to_string(token_path)
662 .map_err(|e| format!("Failed to read token file '{token_path}': {e}"))?;
663
664 let bytes = client
665 .post(endpoint)
666 .query(&[
667 ("Action", "AssumeRoleWithWebIdentity"),
668 ("DurationSeconds", "3600"),
669 ("RoleArn", role_arn),
670 ("RoleSessionName", session_name),
671 ("Version", "2011-06-15"),
672 ("WebIdentityToken", &token),
673 ])
674 .retryable(retry_config)
675 .idempotent(true)
676 .sensitive(true)
677 .send()
678 .await?
679 .into_body()
680 .bytes()
681 .await?;
682
683 let resp: AssumeRoleResponse = quick_xml::de::from_reader(bytes.reader())
684 .map_err(|e| format!("Invalid AssumeRoleWithWebIdentity response: {e}"))?;
685
686 let creds = resp.assume_role_with_web_identity_result.credentials;
687 let now = Utc::now();
688 let ttl = (creds.expiration - now).to_std().unwrap_or_default();
689
690 Ok(TemporaryToken {
691 token: Arc::new(creds.into()),
692 expiry: Some(Instant::now() + ttl),
693 })
694}
695
/// Fetches credentials as JSON from a fixed URL (see `task_credential`),
/// caching them until they expire. Presumably the ECS task/container
/// credentials endpoint; TODO confirm against the builder.
#[derive(Debug)]
pub(crate) struct TaskCredentialProvider {
    /// URL the credentials document is fetched from with `GET`.
    pub url: String,
    /// Retry policy for the fetch.
    pub retry: RetryConfig,
    /// HTTP client used for the fetch.
    pub client: HttpClient,
    /// Cache of the most recently fetched credentials.
    pub cache: TokenCache<Arc<AwsCredential>>,
}
706
707#[async_trait]
708impl CredentialProvider for TaskCredentialProvider {
709 type Credential = AwsCredential;
710
711 async fn get_credential(&self) -> Result<Arc<AwsCredential>> {
712 self.cache
713 .get_or_insert_with(|| task_credential(&self.client, &self.retry, &self.url))
714 .await
715 .map_err(|source| crate::Error::Generic {
716 store: STORE,
717 source,
718 })
719 }
720}
721
722async fn task_credential(
724 client: &HttpClient,
725 retry: &RetryConfig,
726 url: &str,
727) -> Result<TemporaryToken<Arc<AwsCredential>>, StdError> {
728 let creds: InstanceCredentials = client
729 .get(url)
730 .send_retry(retry)
731 .await?
732 .into_body()
733 .json()
734 .await?;
735
736 let now = Utc::now();
737 let ttl = (creds.expiration - now).to_std().unwrap_or_default();
738 Ok(TemporaryToken {
739 token: Arc::new(creds.into()),
740 expiry: Some(Instant::now() + ttl),
741 })
742}
743
/// Fetches credentials from an EKS pod credential endpoint (see
/// `eks_credential`), authenticating with a token read from a file, and caches
/// them until they expire.
#[derive(Debug)]
pub(crate) struct EKSPodCredentialProvider {
    /// URL the credentials document is fetched from with `GET`.
    pub url: String,
    /// File whose contents are sent as the `Authorization` header; it is
    /// re-read on every fetch.
    pub token_file: String,
    /// Retry policy for the fetch.
    pub retry: RetryConfig,
    /// HTTP client used for the fetch.
    pub client: HttpClient,
    /// Cache of the most recently fetched credentials.
    pub cache: TokenCache<Arc<AwsCredential>>,
}
757
758#[async_trait]
759impl CredentialProvider for EKSPodCredentialProvider {
760 type Credential = AwsCredential;
761
762 async fn get_credential(&self) -> Result<Arc<AwsCredential>> {
763 self.cache
764 .get_or_insert_with(|| {
765 eks_credential(&self.client, &self.retry, &self.url, &self.token_file)
766 })
767 .await
768 .map_err(|source| crate::Error::Generic {
769 store: STORE,
770 source,
771 })
772 }
773}
774
775async fn eks_credential(
779 client: &HttpClient,
780 retry: &RetryConfig,
781 url: &str,
782 token_file: &str,
783) -> Result<TemporaryToken<Arc<AwsCredential>>, StdError> {
784 let token = match tokio::runtime::Handle::try_current() {
786 Ok(runtime) => {
787 let path = token_file.to_string();
788 runtime
789 .spawn_blocking(move || std::fs::read_to_string(&path))
790 .await?
791 }
792 Err(_) => std::fs::read_to_string(token_file),
793 }
794 .map_err(|e| format!("Failed to read EKS token file '{token_file}': {e}"))?;
795
796 let mut req = client.request(Method::GET, url);
797 req = req.header("Authorization", token);
798
799 let creds: InstanceCredentials = req.send_retry(retry).await?.into_body().json().await?;
801
802 let now = Utc::now();
803 let ttl = (creds.expiration - now).to_std().unwrap_or_default();
804
805 Ok(TemporaryToken {
806 token: Arc::new(creds.into()),
807 expiry: Some(Instant::now() + ttl),
808 })
809}
810
/// Obtains short-lived session credentials by issuing a SigV4-signed
/// (`"s3"` service) `GET {endpoint}?session` and parsing the
/// `CreateSessionOutput` XML response.
#[derive(Debug)]
pub(crate) struct SessionProvider {
    /// Base URL that `?session` is appended to.
    pub endpoint: String,
    /// Region used when signing the session request.
    pub region: String,
    /// Credentials used to sign the session request.
    pub credentials: AwsCredentialProvider,
}
820
821#[async_trait]
822impl TokenProvider for SessionProvider {
823 type Credential = AwsCredential;
824
825 async fn fetch_token(
826 &self,
827 client: &HttpClient,
828 retry: &RetryConfig,
829 ) -> Result<TemporaryToken<Arc<Self::Credential>>> {
830 let creds = self.credentials.get_credential().await?;
831 let authorizer = AwsAuthorizer::new(&creds, "s3", &self.region);
832
833 let bytes = client
834 .get(format!("{}?session", self.endpoint))
835 .with_aws_sigv4(Some(authorizer), None)
836 .send_retry(retry)
837 .await
838 .map_err(|source| Error::CreateSessionRequest { source })?
839 .into_body()
840 .bytes()
841 .await
842 .map_err(|source| Error::CreateSessionResponse { source })?;
843
844 let resp: CreateSessionOutput = quick_xml::de::from_reader(bytes.reader())
845 .map_err(|source| Error::CreateSessionOutput { source })?;
846
847 let creds = resp.credentials;
848 Ok(TemporaryToken {
849 token: Arc::new(creds.into()),
850 expiry: Some(Instant::now() + Duration::from_secs(5 * 60)),
852 })
853 }
854}
855
/// XML body returned by the `?session` request issued by `SessionProvider`;
/// only the `Credentials` element is read.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct CreateSessionOutput {
    credentials: SessionCredentials,
}
861
862#[cfg(test)]
863mod tests {
864 use super::*;
865 use crate::aws::{AmazonS3Builder, AmazonS3ConfigKey};
866 use crate::client::HttpClient;
867 use crate::client::mock_server::MockServer;
868 use http::Response;
869 use reqwest::{Client, Method};
870 use std::env;
871
872 #[test]
874 fn test_sign_with_signed_payload() {
875 let client = HttpClient::new(Client::new());
876
877 let credential = AwsCredential {
879 key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
880 secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
881 token: None,
882 };
883
884 let date = DateTime::parse_from_rfc3339("2022-08-06T18:01:34Z")
891 .unwrap()
892 .with_timezone(&Utc);
893
894 let mut request = client
895 .request(Method::GET, "https://ec2.amazon.com/")
896 .into_parts()
897 .1
898 .unwrap();
899
900 let signer = AwsAuthorizer {
901 date: Some(date),
902 credential: &credential,
903 service: "ec2",
904 region: "us-east-1",
905 sign_payload: true,
906 token_header: None,
907 request_payer: false,
908 };
909
910 signer.authorize(&mut request, None);
911 assert_eq!(
912 request.headers().get(&AUTHORIZATION).unwrap(),
913 "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20220806/us-east-1/ec2/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=a3c787a7ed37f7fdfbfd2d7056a3d7c9d85e6d52a2bfbec73793c0be6e7862d4"
914 )
915 }
916
917 #[test]
918 fn test_sign_with_signed_payload_request_payer() {
919 let client = HttpClient::new(Client::new());
920
921 let credential = AwsCredential {
923 key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
924 secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
925 token: None,
926 };
927
928 let date = DateTime::parse_from_rfc3339("2022-08-06T18:01:34Z")
935 .unwrap()
936 .with_timezone(&Utc);
937
938 let mut request = client
939 .request(Method::GET, "https://ec2.amazon.com/")
940 .into_parts()
941 .1
942 .unwrap();
943
944 let signer = AwsAuthorizer {
945 date: Some(date),
946 credential: &credential,
947 service: "ec2",
948 region: "us-east-1",
949 sign_payload: true,
950 token_header: None,
951 request_payer: true,
952 };
953
954 signer.authorize(&mut request, None);
955 assert_eq!(
956 request.headers().get(&AUTHORIZATION).unwrap(),
957 "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20220806/us-east-1/ec2/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date;x-amz-request-payer, Signature=7030625a9e9b57ed2a40e63d749f4a4b7714b6e15004cab026152f870dd8565d"
958 )
959 }
960
961 #[test]
962 fn test_sign_with_unsigned_payload() {
963 let client = HttpClient::new(Client::new());
964
965 let credential = AwsCredential {
967 key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
968 secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
969 token: None,
970 };
971
972 let date = DateTime::parse_from_rfc3339("2022-08-06T18:01:34Z")
979 .unwrap()
980 .with_timezone(&Utc);
981
982 let mut request = client
983 .request(Method::GET, "https://ec2.amazon.com/")
984 .into_parts()
985 .1
986 .unwrap();
987
988 let authorizer = AwsAuthorizer {
989 date: Some(date),
990 credential: &credential,
991 service: "ec2",
992 region: "us-east-1",
993 token_header: None,
994 sign_payload: false,
995 request_payer: false,
996 };
997
998 authorizer.authorize(&mut request, None);
999 assert_eq!(
1000 request.headers().get(&AUTHORIZATION).unwrap(),
1001 "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20220806/us-east-1/ec2/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=653c3d8ea261fd826207df58bc2bb69fbb5003e9eb3c0ef06e4a51f2a81d8699"
1002 );
1003 }
1004
1005 #[test]
1006 fn signed_get_url() {
1007 let credential = AwsCredential {
1009 key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
1010 secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
1011 token: None,
1012 };
1013
1014 let date = DateTime::parse_from_rfc3339("2013-05-24T00:00:00Z")
1015 .unwrap()
1016 .with_timezone(&Utc);
1017
1018 let authorizer = AwsAuthorizer {
1019 date: Some(date),
1020 credential: &credential,
1021 service: "s3",
1022 region: "us-east-1",
1023 token_header: None,
1024 sign_payload: false,
1025 request_payer: false,
1026 };
1027
1028 let mut url = Url::parse("https://examplebucket.s3.amazonaws.com/test.txt").unwrap();
1029 authorizer.sign(Method::GET, &mut url, Duration::from_secs(86400));
1030
1031 assert_eq!(
1032 url,
1033 Url::parse(
1034 "https://examplebucket.s3.amazonaws.com/test.txt?\
1035 X-Amz-Algorithm=AWS4-HMAC-SHA256&\
1036 X-Amz-Credential=AKIAIOSFODNN7EXAMPLE%2F20130524%2Fus-east-1%2Fs3%2Faws4_request&\
1037 X-Amz-Date=20130524T000000Z&\
1038 X-Amz-Expires=86400&\
1039 X-Amz-SignedHeaders=host&\
1040 X-Amz-Signature=aeeed9bbccd4d02ee5c0109b86d86835f995330da4c265957d157751f604d404"
1041 )
1042 .unwrap()
1043 );
1044 }
1045
1046 #[test]
1047 fn signed_get_url_request_payer() {
1048 let credential = AwsCredential {
1050 key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
1051 secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
1052 token: None,
1053 };
1054
1055 let date = DateTime::parse_from_rfc3339("2013-05-24T00:00:00Z")
1056 .unwrap()
1057 .with_timezone(&Utc);
1058
1059 let authorizer = AwsAuthorizer {
1060 date: Some(date),
1061 credential: &credential,
1062 service: "s3",
1063 region: "us-east-1",
1064 token_header: None,
1065 sign_payload: false,
1066 request_payer: true,
1067 };
1068
1069 let mut url = Url::parse("https://examplebucket.s3.amazonaws.com/test.txt").unwrap();
1070 authorizer.sign(Method::GET, &mut url, Duration::from_secs(86400));
1071
1072 assert_eq!(
1073 url,
1074 Url::parse(
1075 "https://examplebucket.s3.amazonaws.com/test.txt?\
1076 X-Amz-Algorithm=AWS4-HMAC-SHA256&\
1077 X-Amz-Credential=AKIAIOSFODNN7EXAMPLE%2F20130524%2Fus-east-1%2Fs3%2Faws4_request&\
1078 X-Amz-Date=20130524T000000Z&\
1079 X-Amz-Expires=86400&\
1080 X-Amz-SignedHeaders=host&\
1081 x-amz-request-payer=requester&\
1082 X-Amz-Signature=9ad7c781cc30121f199b47d35ed3528473e4375b63c5d91cd87c927803e4e00a"
1083 )
1084 .unwrap()
1085 );
1086 }
1087
1088 #[test]
1089 fn test_sign_port() {
1090 let client = HttpClient::new(Client::new());
1091
1092 let credential = AwsCredential {
1093 key_id: "H20ABqCkLZID4rLe".to_string(),
1094 secret_key: "jMqRDgxSsBqqznfmddGdu1TmmZOJQxdM".to_string(),
1095 token: None,
1096 };
1097
1098 let date = DateTime::parse_from_rfc3339("2022-08-09T13:05:25Z")
1099 .unwrap()
1100 .with_timezone(&Utc);
1101
1102 let mut request = client
1103 .request(Method::GET, "http://localhost:9000/tsm-schemas")
1104 .query(&[
1105 ("delimiter", "/"),
1106 ("encoding-type", "url"),
1107 ("list-type", "2"),
1108 ("prefix", ""),
1109 ])
1110 .into_parts()
1111 .1
1112 .unwrap();
1113
1114 let authorizer = AwsAuthorizer {
1115 date: Some(date),
1116 credential: &credential,
1117 service: "s3",
1118 region: "us-east-1",
1119 token_header: None,
1120 sign_payload: true,
1121 request_payer: false,
1122 };
1123
1124 authorizer.authorize(&mut request, None);
1125 assert_eq!(
1126 request.headers().get(&AUTHORIZATION).unwrap(),
1127 "AWS4-HMAC-SHA256 Credential=H20ABqCkLZID4rLe/20220809/us-east-1/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=9ebf2f92872066c99ac94e573b4e1b80f4dbb8a32b1e8e23178318746e7d1b4d"
1128 )
1129 }
1130
1131 #[tokio::test]
1132 async fn test_instance_metadata() {
1133 if env::var("TEST_INTEGRATION").is_err() {
1134 eprintln!("skipping AWS integration test");
1135 return;
1136 }
1137
1138 let endpoint = env::var("EC2_METADATA_ENDPOINT").unwrap();
1140 let client = HttpClient::new(Client::new());
1141 let retry_config = RetryConfig::default();
1142
1143 let (client, req) = client
1145 .request(Method::GET, format!("{endpoint}/latest/meta-data/ami-id"))
1146 .into_parts();
1147
1148 let resp = client.execute(req.unwrap()).await.unwrap();
1149
1150 assert_eq!(
1151 resp.status(),
1152 StatusCode::UNAUTHORIZED,
1153 "Ensure metadata endpoint is set to only allow IMDSv2"
1154 );
1155
1156 let creds = instance_creds(&client, &retry_config, &endpoint, false)
1157 .await
1158 .unwrap();
1159
1160 let id = &creds.token.key_id;
1161 let secret = &creds.token.secret_key;
1162 let token = creds.token.token.as_ref().unwrap();
1163
1164 assert!(!id.is_empty());
1165 assert!(!secret.is_empty());
1166 assert!(!token.is_empty())
1167 }
1168
1169 #[tokio::test]
1170 async fn test_mock() {
1171 let server = MockServer::new().await;
1172
1173 const IMDSV2_HEADER: &str = "X-aws-ec2-metadata-token";
1174
1175 let secret_access_key = "SECRET";
1176 let access_key_id = "KEYID";
1177 let token = "TOKEN";
1178
1179 let endpoint = server.url();
1180 let client = HttpClient::new(Client::new());
1181 let retry_config = RetryConfig::default();
1182
1183 server.push_fn(|req| {
1185 assert_eq!(req.uri().path(), "/latest/api/token");
1186 assert_eq!(req.method(), &Method::PUT);
1187 Response::new("cupcakes".to_string())
1188 });
1189 server.push_fn(|req| {
1190 assert_eq!(
1191 req.uri().path(),
1192 "/latest/meta-data/iam/security-credentials/"
1193 );
1194 assert_eq!(req.method(), &Method::GET);
1195 let t = req.headers().get(IMDSV2_HEADER).unwrap().to_str().unwrap();
1196 assert_eq!(t, "cupcakes");
1197 Response::new("myrole".to_string())
1198 });
1199 server.push_fn(|req| {
1200 assert_eq!(req.uri().path(), "/latest/meta-data/iam/security-credentials/myrole");
1201 assert_eq!(req.method(), &Method::GET);
1202 let t = req.headers().get(IMDSV2_HEADER).unwrap().to_str().unwrap();
1203 assert_eq!(t, "cupcakes");
1204 Response::new(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#.to_string())
1205 });
1206
1207 let creds = instance_creds(&client, &retry_config, endpoint, true)
1208 .await
1209 .unwrap();
1210
1211 assert_eq!(creds.token.token.as_deref().unwrap(), token);
1212 assert_eq!(&creds.token.key_id, access_key_id);
1213 assert_eq!(&creds.token.secret_key, secret_access_key);
1214
1215 server.push_fn(|req| {
1217 assert_eq!(req.uri().path(), "/latest/api/token");
1218 assert_eq!(req.method(), &Method::PUT);
1219 Response::builder()
1220 .status(StatusCode::FORBIDDEN)
1221 .body(String::new())
1222 .unwrap()
1223 });
1224 server.push_fn(|req| {
1225 assert_eq!(
1226 req.uri().path(),
1227 "/latest/meta-data/iam/security-credentials/"
1228 );
1229 assert_eq!(req.method(), &Method::GET);
1230 assert!(req.headers().get(IMDSV2_HEADER).is_none());
1231 Response::new("myrole".to_string())
1232 });
1233 server.push_fn(|req| {
1234 assert_eq!(req.uri().path(), "/latest/meta-data/iam/security-credentials/myrole");
1235 assert_eq!(req.method(), &Method::GET);
1236 assert!(req.headers().get(IMDSV2_HEADER).is_none());
1237 Response::new(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#.to_string())
1238 });
1239
1240 let creds = instance_creds(&client, &retry_config, endpoint, true)
1241 .await
1242 .unwrap();
1243
1244 assert_eq!(creds.token.token.as_deref().unwrap(), token);
1245 assert_eq!(&creds.token.key_id, access_key_id);
1246 assert_eq!(&creds.token.secret_key, secret_access_key);
1247
1248 server.push(
1250 Response::builder()
1251 .status(StatusCode::FORBIDDEN)
1252 .body(String::new())
1253 .unwrap(),
1254 );
1255
1256 instance_creds(&client, &retry_config, endpoint, false)
1258 .await
1259 .unwrap_err();
1260 }
1261
1262 #[tokio::test]
1263 async fn test_eks_pod_credential_provider() {
1264 use crate::client::mock_server::MockServer;
1265 use http::Response;
1266 use std::fs::File;
1267 use std::io::Write;
1268
1269 let mock_server = MockServer::new().await;
1270
1271 mock_server.push(Response::new(
1272 r#"{
1273 "AccessKeyId": "TEST_KEY",
1274 "SecretAccessKey": "TEST_SECRET",
1275 "Token": "TEST_SESSION_TOKEN",
1276 "Expiration": "2100-01-01T00:00:00Z"
1277 }"#
1278 .to_string(),
1279 ));
1280
1281 let token_file = tempfile::NamedTempFile::new().expect("cannot create temp file");
1282 let path = token_file.path().to_string_lossy().into_owned();
1283 let mut f = File::create(token_file.path()).unwrap();
1284 write!(f, "TEST_BEARER_TOKEN").unwrap();
1285
1286 let builder = AmazonS3Builder::new()
1287 .with_bucket_name("some-bucket")
1288 .with_config(
1289 AmazonS3ConfigKey::ContainerCredentialsFullUri,
1290 mock_server.url(),
1291 )
1292 .with_config(AmazonS3ConfigKey::ContainerAuthorizationTokenFile, &path);
1293
1294 let s3 = builder.build().unwrap();
1295
1296 let cred = s3.client.config.credentials.get_credential().await.unwrap();
1297
1298 assert_eq!(cred.key_id, "TEST_KEY");
1299 assert_eq!(cred.secret_key, "TEST_SECRET");
1300 assert_eq!(cred.token.as_deref(), Some("TEST_SESSION_TOKEN"));
1301 }
1302
/// The `Debug` rendering of `AwsCredential` may show the key id but must
/// replace the secret key and session token with a fixed mask.
#[test]
fn test_output_masks_all_fields() {
    let secret = "super_secret";
    let session_token = "temp_token";
    let credential = AwsCredential {
        key_id: "AKIAXXX".to_string(),
        secret_key: secret.to_string(),
        token: Some(session_token.to_string()),
    };

    let rendered = format!("{credential:?}");

    // Fields that must appear, with sensitive values masked.
    for expected in [
        "key_id: \"AKIAXXX\"",
        "secret_key: \"******\"",
        "token: Some(\"******\")",
    ] {
        assert!(rendered.contains(expected));
    }

    // Raw sensitive values that must never leak into the output.
    for leaked in [secret, session_token] {
        assert!(!rendered.contains(leaked));
    }
}
1320
/// Table-driven check of `append_normalized_whitespace_value`.
///
/// The expected values show that leading and trailing whitespace
/// (space, tab, LF, VT, CR, FF) is dropped and interior runs collapse to a
/// single space.
#[test]
fn test_normalize_whitespace() {
    let cases = [
        ("本語", "本語"),
        (" abc ", "abc"),
        (" a b ", "a b"),
        ("a b ", "a b"),
        ("a b", "a b"),
        ("a b", "a b"),
        (" a b c ", "a b c"),
        ("a \t b c ", "a b c"),
        ("\"a \t b c ", "\"a b c"),
        (
            " \t\n\u{000b}\r\u{000c}a \t\n\u{000b}\r\u{000c} b \t\n\u{000b}\r\u{000c} c \t\n\u{000b}\r\u{000c}",
            "a b c",
        ),
    ];

    // One buffer, emptied before each case, so every call starts from "".
    let mut normalized = String::new();
    for (raw, want) in cases {
        normalized.clear();
        append_normalized_whitespace_value(&mut normalized, raw);
        assert_eq!(normalized, want);
    }
}
1348
/// `canonicalize_headers` should leave `authorization` and `content-length`
/// out of the signed set. The remaining headers should come out sorted by
/// name with their values whitespace-normalized.
#[test]
fn test_canonicalize_headers_whitespace_normalization() {
    use http::header::HeaderMap;

    let entries = [
        ("x-amz-meta-example", " foo bar "),
        ("x-amz-meta-another", " multiple spaces here "),
        ("x-amz-meta-and-another-one", "foo\t\t\t bar"),
        ("authorization", "SHOULD_BE_IGNORED"),
        ("content-length", "1337"),
    ];

    let mut headers = HeaderMap::new();
    for (name, value) in entries {
        headers.insert(name, value.parse().unwrap());
    }

    let (signed, canonical) = super::canonicalize_headers(&headers);

    assert_eq!(
        signed,
        "x-amz-meta-and-another-one;x-amz-meta-another;x-amz-meta-example"
    );
    assert_eq!(
        canonical,
        concat!(
            "x-amz-meta-and-another-one:foo bar\n",
            "x-amz-meta-another:multiple spaces here\n",
            "x-amz-meta-example:foo bar\n",
        )
    );
}
1379}