1use std::borrow::{Borrow as _, Cow};
8use std::fmt;
9
10use base64ct::{Base64UrlUnpadded, Encoding as _};
11use hmac::Mac as _;
12use rsa::{
13 pkcs8::{
14 DecodePrivateKey as _, DecodePublicKey as _, EncodePrivateKey as _, EncodePublicKey as _,
15 },
16 signature::{SignatureEncoding as _, Signer as _, Verifier as _},
17 traits::PublicKeyParts as _,
18};
19use serde::{
20 Deserialize, Deserializer, Serialize,
21 de::{DeserializeOwned, Visitor},
22};
23
24use crate::id;
25use crate::misc::time_ext;
26use crate::phcrypto;
27
28#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
32#[serde(transparent)]
33pub struct JWT {
34 inner: String,
35}
36
37#[derive(Debug, Clone, Default, serde::Serialize)]
39#[serde(transparent)]
40pub struct Claims {
41 inner: serde_json::Map<String, serde_json::Value>,
42}
43
44impl Claims {
45 pub fn new() -> Self {
46 Default::default()
47 }
48
49 pub fn check<'s, V: Deserialize<'s>>(
53 mut self,
54 name: &'static str,
55 expectation: impl FnOnce(&'static str, Option<V>) -> Result<(), Error>,
56 ) -> Result<Self, Error> {
57 let value: Option<V> = self
58 .inner
59 .remove(name)
60 .map(V::deserialize)
61 .transpose()
62 .map_err(|err| Error::DeserializingClaim {
63 claim_name: name,
64 source: err,
65 })?;
66
67 expectation(name, value)?;
68
69 Ok(self)
70 }
71
72 pub fn check_present_and<'s, V: Deserialize<'s>>(
75 self,
76 name: &'static str,
77 expectation: impl FnOnce(&'static str, V) -> Result<(), Error>,
78 ) -> Result<Self, Error> {
79 self.check(
80 name,
81 |claim_name: &'static str, v: Option<V>| -> Result<(), Error> {
82 if v.is_none() {
83 return Err(Error::MissingClaim(claim_name));
84 }
85 expectation(name, v.unwrap())
86 },
87 )
88 }
89
90 pub fn check_no(self, name: &'static str) -> Result<Self, Error> {
92 if self.inner.contains_key(name) {
93 return Err(Error::UnexpectedClaim(name));
94 }
95
96 Ok(self)
97 }
98
99 pub fn extract<V: DeserializeOwned>(&mut self, name: &'static str) -> Result<Option<V>, Error> {
101 let Some(json_value) = self.inner.remove(name) else {
102 return Ok(None);
103 };
104
105 let deserialized_value =
106 V::deserialize(json_value).map_err(|err| Error::DeserializingClaim {
107 claim_name: name,
108 source: err,
109 })?;
110
111 Ok(Some(deserialized_value))
112 }
113
114 pub fn ignore(mut self, name: &'static str) -> Self {
116 self.inner.remove(name);
117 self
118 }
119
120 pub fn check_iss(
122 self,
123 expectation: impl FnOnce(&'static str, Option<String>) -> Result<(), Error>,
124 ) -> Result<Self, Error> {
125 self.check("iss", expectation)
126 }
127
128 pub fn check_sub(
130 self,
131 expectation: impl FnOnce(&'static str, Option<String>) -> Result<(), Error>,
132 ) -> Result<Self, Error> {
133 self.check("sub", expectation)
134 }
135
136 pub fn default_check_timestamps(self) -> Result<Self, Error> {
139 let now = NumericDate::now();
140
141 self.check(
142 "iat",
143 |_claim_name: &'static str, _iat: Option<NumericDate>| -> Result<(), Error> {
144 Ok(())
146 },
147 )?
148 .check(
149 "exp",
150 |_claim_name: &'static str, exp: Option<NumericDate>| -> Result<(), Error> {
151 if let Some(exp) = exp
153 && exp < now
154 {
155 return Err(Error::Expired { when: exp });
156 }
157
158 Ok(())
159 },
160 )?
161 .check(
162 "nbf",
163 |_claim_name: &'static str, nbf: Option<NumericDate>| -> Result<(), Error> {
164 if let Some(nbf) = nbf
166 && now < nbf
167 {
168 return Err(Error::NotYetValid { valid_from: nbf });
169 }
170
171 Ok(())
172 },
173 )
174 }
175
176 pub fn default_check_common_claims(self) -> Result<Self, Error> {
179 self.default_check_timestamps()?
180 .check_no("iss")?
181 .check_no("sub")
182 }
183
184 pub fn visit_custom<C: DeserializeOwned, R>(
195 self,
196 visitor: impl FnOnce(C) -> R,
197 ) -> Result<R, Error> {
198 let self_ = self.default_check_common_claims()?;
200
201 let jso = serde_json::Value::Object(self_.inner);
202 let claims: C = C::deserialize(&jso).map_err(|err| {
203 let jso_str = serde_json::to_string_pretty(&jso).unwrap();
204
205 if let Err(better_err) = serde_json::from_str::<C>(&jso_str) {
206 return Error::DeserializingClaims {
207 source: better_err,
208 claims: jso_str,
209 };
210 }
211
212 log::error!("something fishy is going on here with this faulty json");
213 Error::DeserializingClaims {
214 source: err,
215 claims: "".to_string(),
216 }
217 })?;
218
219 Ok(visitor(claims))
220 }
221
222 pub fn into_custom<C: DeserializeOwned>(self) -> Result<C, Error> {
225 self.visit_custom(|c| c)
226 }
227
228 pub fn from_custom<C: Serialize>(claims: C) -> Result<Self, Error> {
230 let json_value = serde_json::to_value(claims).map_err(Error::SerializingClaims)?;
231
232 Ok(Self {
233 inner: match json_value {
234 serde_json::Value::Object(inner) => inner,
235 serde_json::Value::Null => {
236 return Err(Error::ClaimsDontSerializeToMapButNull {
237 claims_type: std::any::type_name::<C>(),
238 });
239 }
240 _ => {
241 return Err(Error::ClaimsDontSerializeToMap {
242 claims_type: std::any::type_name::<C>(),
243 });
244 }
245 },
246 })
247 }
248
249 pub fn claim<V: Serialize>(mut self, name: &'static str, value: V) -> Result<Self, Error> {
252 let old_value = self.inner.insert(
253 name.to_string(),
254 serde_json::to_value(value).map_err(|err| Error::SerializingClaim {
255 claim_name: name,
256 source: err,
257 })?,
258 );
259
260 if old_value.is_some() {
261 return Err(Error::ClaimAlreadyPresent(name));
262 }
263
264 Ok(self)
265 }
266
267 pub fn iat_now(self) -> Result<Self, Error> {
269 self.claim("iat", NumericDate::now())
270 }
271
272 pub fn exp_after(self, duration: std::time::Duration) -> Result<Self, Error> {
274 self.claim("exp", NumericDate::now() + duration)
275 }
276
277 pub fn nbf(self) -> Result<Self, Error> {
279 self.claim("nbf", NumericDate::now() - 30)
280 }
281
282 pub fn sign<SK: SigningKey>(&self, sk: &SK) -> Result<JWT, Error> {
284 JWT::create(&self.inner, sk)
285 }
286}
287
288#[derive(Serialize, Default, Clone, Copy, Eq, PartialEq, Debug, PartialOrd, Ord)]
302#[serde(transparent)]
303pub struct NumericDate {
304 timestamp: u64,
305}
306
307impl NumericDate {
308 pub fn new(timestamp: u64) -> Self {
311 Self { timestamp }
312 }
313
314 pub fn now() -> Self {
316 std::time::SystemTime::now().into()
317 }
318
319 pub fn timestamp(&self) -> u64 {
321 self.timestamp
322 }
323
324 pub fn date(&self) -> String {
326 let mut datetime = humantime::format_rfc3339(self.into()).to_string();
329
330 let Some(idx) = datetime.find('T') else {
331 panic!("bug: expected date returned by humantime to contain a 'T'");
332 };
333
334 datetime.truncate(idx);
335
336 datetime
337 }
338}
339
340impl From<std::time::SystemTime> for NumericDate {
341 fn from(st: std::time::SystemTime) -> Self {
342 Self::new(
343 st.duration_since(std::time::UNIX_EPOCH)
344 .expect("before unix epoch")
345 .as_secs(),
346 )
347 }
348}
349
350impl From<&NumericDate> for std::time::SystemTime {
351 fn from(nd: &NumericDate) -> Self {
352 std::time::UNIX_EPOCH + std::time::Duration::from_secs(nd.timestamp)
353 }
354}
355
356impl fmt::Display for NumericDate {
357 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
358 write!(f, "{}", time_ext::format_time(self.into()))
359 }
360}
361
362impl<'de> Deserialize<'de> for NumericDate {
363 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
364 d.deserialize_u64(NumericDateVisitor {})
365 }
366}
367
368struct NumericDateVisitor {}
370
371impl Visitor<'_> for NumericDateVisitor {
372 type Value = NumericDate;
373
374 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
375 write!(f, "a non-negative number")
376 }
377
378 fn visit_u64<E: serde::de::Error>(self, v: u64) -> Result<Self::Value, E> {
379 Ok(NumericDate::new(v))
380 }
381
382 fn visit_i64<E: serde::de::Error>(self, v: i64) -> Result<Self::Value, E> {
383 if v < 0 {
384 return Err(E::invalid_value(serde::de::Unexpected::Signed(v), &self));
385 }
386
387 self.visit_u64(v as u64)
388 }
389
390 fn visit_f64<E: serde::de::Error>(self, v: f64) -> Result<Self::Value, E> {
391 if v < 0.0 {
392 return Err(E::invalid_value(serde::de::Unexpected::Float(v), &self));
393 }
394
395 self.visit_u64(v as u64)
396 }
397
398 }
400
401impl core::ops::Add<std::time::Duration> for NumericDate {
402 type Output = Self;
403
404 fn add(self, duration: std::time::Duration) -> Self::Output {
405 self + duration.as_secs()
406 }
407}
408
409impl core::ops::Add<u64> for NumericDate {
410 type Output = Self;
411
412 fn add(mut self, secs: u64) -> Self::Output {
413 self.timestamp += secs;
414 self
415 }
416}
417
418impl core::ops::Sub<u64> for NumericDate {
419 type Output = Self;
420
421 fn sub(mut self, secs: u64) -> Self::Output {
422 self.timestamp -= secs;
423 self
424 }
425}
426
427impl From<String> for JWT {
428 fn from(s: String) -> Self {
429 Self { inner: s }
430 }
431}
432
433impl From<JWT> for String {
434 fn from(jwt: JWT) -> String {
435 jwt.inner
436 }
437}
438
439impl JWT {
440 pub fn create<C: Serialize, SK: SigningKey>(claims: &C, key: &SK) -> Result<JWT, Error> {
444 let to_be_signed: String = format!(
445 "{}.{}",
446 Base64UrlUnpadded::encode_string(
447 &serde_json::to_vec(&serde_json::json!({
448 "alg": SK::ALG,
449 }))
450 .map_err(Error::SerializingHeader)?
451 ),
452 &Base64UrlUnpadded::encode_string(
453 &serde_json::to_vec(claims).map_err(Error::SerializingClaims)?
454 )
455 );
456 Ok(JWT::from(format!(
457 "{}.{}",
458 to_be_signed,
459 Base64UrlUnpadded::encode_string(
460 key.sign(to_be_signed.as_bytes())
461 .map_err(Error::Signing)?
462 .as_ref()
463 )
464 )))
465 }
466
467 pub fn open<VK: VerifyingKey>(&self, key: &VK) -> Result<Claims, Error> {
473 let s = &self.inner;
474
475 let last_dot_pos: usize = s.rfind('.').ok_or(Error::MissingDot)?;
476 let signed: &str = &s[..last_dot_pos];
477 let first_dot_pos: usize = signed.find('.').ok_or(Error::MissingDot)?;
478
479 let header_vec: Vec<u8> =
481 Base64UrlUnpadded::decode_vec(&s[..first_dot_pos]).map_err(Error::InvalidBase64)?;
482
483 let header: Header =
484 serde_json::from_slice(&header_vec).map_err(Error::DeserializingHeader)?;
485
486 VK::check_alg(&header.alg)?;
487
488 let signature: Vec<u8> =
489 Base64UrlUnpadded::decode_vec(&s[last_dot_pos + 1..]).map_err(Error::InvalidBase64)?;
490
491 let claims_vec: Vec<u8> = Base64UrlUnpadded::decode_vec(&signed[first_dot_pos + 1..])
493 .map_err(Error::InvalidBase64)?;
494
495 let mut d = serde_json::Deserializer::from_slice(&claims_vec);
496
497 let claims = Claims {
498 inner: serde_json::Map::<String, serde_json::Value>::deserialize(&mut d)
499 .map_err(Error::ClaimsNotJsonMap)?,
500 };
501
502 if !key.is_valid_signature(signed.as_bytes(), signature) {
504 return Err(Error::InvalidSignature {
505 key: key.describe(),
506 claims,
507 });
508 }
509
510 Ok(claims)
511 }
512
513 pub fn as_str(&self) -> &str {
514 &self.inner
515 }
516
517 pub fn sha256(&self) -> sha2::Sha256 {
518 use sha2::Digest as _;
519 sha2::Sha256::new().chain_update(&self.inner)
520 }
521
522 pub fn id(&self) -> id::Id {
523 phcrypto::jwt_id(self)
524 }
525}
526
527impl fmt::Display for JWT {
528 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
529 write!(f, "{}", self.inner)
530 }
531}
532
533#[derive(thiserror::Error, Debug)]
534pub enum Error {
535 #[error("failed to serialize jwt header")]
536 SerializingHeader(#[source] serde_json::Error),
537
538 #[error("invalid jwt header")]
539 DeserializingHeader(#[source] serde_json::Error),
540
541 #[error("failed to serialize jwt claims")]
542 SerializingClaims(#[source] serde_json::Error),
543
544 #[error("failed to serialize claim {claim_name}")]
545 SerializingClaim {
546 claim_name: &'static str,
547 source: serde_json::Error,
548 },
549
550 #[error("claim {0} already present")]
551 ClaimAlreadyPresent(&'static str),
552
553 #[error("claims are not a valid json map")]
554 ClaimsNotJsonMap(#[source] serde_json::Error),
555
556 #[error("the given custom claims (of type {claims_type}) do not serialize to a json map")]
557 ClaimsDontSerializeToMap { claims_type: &'static str },
558
559 #[error(
560 "the given custom claims (of type {claims_type}) do not serialize to a json map, but to null. Hint: 'type Unit;' -> 'type Unit {{}}'"
561 )]
562 ClaimsDontSerializeToMapButNull { claims_type: &'static str },
563
564 #[error("invalid jwt claims: {source} in {claims}")]
565 DeserializingClaims {
566 source: serde_json::Error,
567 claims: String,
568 },
569
570 #[error("failed to deserialize claim {claim_name}")]
571 DeserializingClaim {
572 claim_name: &'static str,
573 source: serde_json::Error,
574 },
575
576 #[error("jwt contains unexpected/unhandled claim `{0}`")]
577 UnexpectedClaim(&'static str),
578
579 #[error("jwt is missing the claim `{0}'")]
580 MissingClaim(&'static str),
581
582 #[error("the claim `{claim_name}` is invalid")]
583 InvalidClaim {
584 claim_name: &'static str,
585 source: anyhow::Error,
586 },
587
588 #[error("expired at {when}")]
589 Expired { when: NumericDate },
590
591 #[error("only valid after {valid_from}")]
592 NotYetValid { valid_from: NumericDate },
593
594 #[error("signing jwt failed")]
595 Signing(#[source] anyhow::Error),
596
597 #[error("missing dot (.) in jwt (there should be two dots)")]
598 MissingDot,
599
600 #[error("jwt contains invalid unpadded urlsafe base64")]
601 InvalidBase64(#[source] base64ct::Error),
602
603 #[error("jwt signature is not valid (for this key, {key})")]
604 InvalidSignature { key: String, claims: Claims },
605
606 #[error("unexpected algorithm; got {got}, but expected {expected}")]
607 UnexpectedAlgorithm { got: String, expected: &'static str },
608}
609
610pub fn sign<SK: SigningKey>(claims: &impl Serialize, key: &SK) -> anyhow::Result<String> {
613 Ok(JWT::create(claims, key)?.inner)
614}
615
/// The current Unix time in whole seconds.
///
/// # Panics
///
/// If the system clock reports a time before the Unix epoch.
pub fn get_current_timestamp() -> u64 {
    let elapsed = std::time::UNIX_EPOCH
        .elapsed()
        .expect("system clock reports a time before the Unix epoch");
    elapsed.as_secs()
}
624
625#[derive(Serialize, Deserialize, Debug)]
627#[serde(deny_unknown_fields)]
628struct Header<'a> {
629 #[serde(
630 rename = "typ",
631 skip_serializing, default )]
634 _typ: HeaderType,
635
636 #[serde(borrow)]
637 alg: Cow<'a, str>,
638 }
640
641#[derive(Default, Debug)]
643struct HeaderType {}
644
645impl<'de> Deserialize<'de> for HeaderType {
646 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
647 d.deserialize_str(HeaderType {})
648 }
649}
650
651impl Visitor<'_> for HeaderType {
652 type Value = Self;
653
654 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
655 write!(f, "the string \"JWT\" as \"typ\"")
656 }
657
658 fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
659 if "JWT".eq_ignore_ascii_case(v) {
660 return Ok(self);
661 }
662
663 Err(E::invalid_value(serde::de::Unexpected::Str(v), &self))
664 }
665}
666
667pub trait SigningKey: Key {
669 type Signature: AsRef<[u8]>;
671
672 fn sign(&self, s: &[u8]) -> anyhow::Result<Self::Signature>;
674
675 fn jwk(&self) -> serde_json::Value;
677}
678
679pub trait VerifyingKey: Key {
681 fn is_valid_signature(&self, message: &[u8], signature: Vec<u8>) -> bool;
683
684 fn describe(&self) -> String;
686}
687
688pub trait Key {
690 const ALG: &'static str;
692
693 fn check_alg(alg: &str) -> Result<(), Error> {
697 if alg == Self::ALG {
698 return Ok(());
699 }
700 Err(Error::UnexpectedAlgorithm {
701 got: alg.to_string(),
702 expected: Self::ALG,
703 })
704 }
705}
706
707pub struct IgnoreSignature;
711
712impl Key for IgnoreSignature {
713 const ALG: &'static str = "WARNING! This should never appear in the 'alg' field of a JWT.";
714
715 fn check_alg(_alg: &str) -> Result<(), Error> {
716 Ok(())
717 }
718}
719
720impl VerifyingKey for IgnoreSignature {
721 fn is_valid_signature(&self, _message: &[u8], _signature: Vec<u8>) -> bool {
722 true
723 }
724
725 fn describe(&self) -> String {
726 "n/a".into()
727 }
728}
729
730impl SigningKey for ed25519_dalek::SigningKey {
766 type Signature = [u8; 64]; fn sign(&self, s: &[u8]) -> anyhow::Result<[u8; 64]> {
769 Ok(ed25519_dalek::Signer::sign(self, s).to_bytes())
770 }
771
772 fn jwk(&self) -> serde_json::Value {
773 serde_json::json!({
774 "kty": "OKP", "alg": Self::ALG,
776 "crv": "Ed25519",
777 "x": Base64UrlUnpadded::encode_string(AsRef::<ed25519_dalek::VerifyingKey>::as_ref(self).as_bytes()),
778 "use": "sig",
780 })
781 }
782}
783
784impl Key for ed25519_dalek::SigningKey {
785 const ALG: &'static str = "EdDSA";
786}
787
788impl Key for ed25519_dalek::VerifyingKey {
789 const ALG: &'static str = "EdDSA";
790}
791
792impl VerifyingKey for ed25519_dalek::VerifyingKey {
793 fn is_valid_signature(&self, message: &[u8], signature: Vec<u8>) -> bool {
794 if let Ok(signature) = ed25519_dalek::Signature::from_slice(&signature) {
795 return ed25519_dalek::Verifier::verify(self, message, &signature).is_ok();
796 }
797 false
798 }
799
800 fn describe(&self) -> String {
801 base16ct::lower::encode_string(self.as_bytes().as_slice())
802 }
803}
804
805#[derive(
807 serde::Serialize, serde::Deserialize, Clone, Debug, Eq, PartialEq, zeroize::ZeroizeOnDrop,
808)]
809#[serde(transparent)]
810pub struct HS256(#[serde(with = "serde_bytes")] pub Vec<u8>);
811
812impl SigningKey for HS256 {
824 type Signature = sha2::digest::generic_array::GenericArray<
825 u8,
826 <sha2::Sha256 as sha2::digest::OutputSizeUser>::OutputSize,
827 >;
828
829 fn sign(&self, s: &[u8]) -> anyhow::Result<Self::Signature> {
830 let mut mac = hmac::Hmac::<sha2::Sha256>::new_from_slice(&self.0)?;
831 mac.update(s);
832 Ok(mac.finalize().into_bytes())
833 }
834
835 fn jwk(&self) -> serde_json::Value {
836 panic!("HS256 has no public key to describe using JWK");
837 }
838}
839
840impl VerifyingKey for HS256 {
849 fn is_valid_signature(&self, message: &[u8], signature: Vec<u8>) -> bool {
850 let mut mac = hmac::Hmac::<sha2::Sha256>::new_from_slice(&self.0)
851 .expect("expect a sha256 mac to accept a key of any size");
852 mac.update(message);
853 mac.verify_slice(&signature).is_ok()
854 }
855
856 fn describe(&self) -> String {
857 base16ct::lower::encode_string(&self.0)
858 }
859}
860
861impl Key for HS256 {
862 const ALG: &'static str = "HS256";
863}
864
865#[derive(Clone, Debug)]
871pub struct RS256Vk(rsa::pkcs1v15::VerifyingKey<sha2::Sha256>);
872
873impl RS256Vk {
874 pub fn new(pk: rsa::RsaPublicKey) -> Self {
875 Self(rsa::pkcs1v15::VerifyingKey::<sha2::Sha256>::new(pk))
877 }
878
879 pub fn from_public_key_pem(pem: &str) -> anyhow::Result<Self> {
880 Ok(Self(
881 rsa::pkcs1v15::VerifyingKey::<sha2::Sha256>::from_public_key_pem(pem)?,
882 ))
883 }
884
885 pub fn to_public_key_pem(&self) -> anyhow::Result<String> {
886 Ok(self.0.to_public_key_pem(Default::default())?)
887 }
888
889 pub fn as_rsa_pk(&self) -> &rsa::RsaPublicKey {
891 AsRef::<rsa::RsaPublicKey>::as_ref(&self.0)
892 }
893}
894
895impl PartialEq for RS256Vk {
899 fn eq(&self, other: &Self) -> bool {
900 self.as_rsa_pk() == other.as_rsa_pk()
902 }
903}
904
905impl Eq for RS256Vk {}
907
908impl Key for RS256Vk {
909 const ALG: &'static str = "RS256";
910}
911
912impl VerifyingKey for RS256Vk {
913 fn is_valid_signature(&self, message: &[u8], signature: Vec<u8>) -> bool {
914 let signature: rsa::pkcs1v15::Signature = match signature.as_slice().try_into() {
915 Ok(signature) => signature,
916 Err(_) => return false,
917 };
918
919 self.0.verify(message, &signature).is_ok()
920 }
921
922 fn describe(&self) -> String {
923 format!("{self:?}")
924 }
925}
926
927#[derive(Clone, Debug)]
929pub struct RS256Sk(rsa::pkcs1v15::SigningKey<sha2::Sha256>);
930
931impl PartialEq for RS256Sk {
932 fn eq(&self, other: &Self) -> bool {
933 self.as_rsa_priv() == other.as_rsa_priv()
934 }
935}
936
937impl Eq for RS256Sk {}
938
939impl Key for RS256Sk {
940 const ALG: &'static str = RS256Vk::ALG;
941}
942
943impl SigningKey for RS256Sk {
944 type Signature = Box<[u8]>;
945
946 fn sign(&self, s: &[u8]) -> anyhow::Result<Self::Signature> {
947 Ok(self.0.sign(s).to_bytes())
948 }
949
950 fn jwk(&self) -> serde_json::Value {
951 let rsa_pub: &rsa::RsaPublicKey = self.as_rsa_pub();
952
953 serde_json::json!({
954 "kty": "RSA",
955 "alg": Self::ALG,
956 "mod": Base64UrlUnpadded::encode_string(&rsa_pub.n().to_bytes_be()),
957 "exp": Base64UrlUnpadded::encode_string(&rsa_pub.e().to_bytes_be()),
958 })
959 }
960}
961
962impl RS256Sk {
963 pub fn new(pk: rsa::RsaPrivateKey) -> Self {
964 Self(rsa::pkcs1v15::SigningKey::<sha2::Sha256>::new(pk))
965 }
966
967 pub fn random(bit_size: usize) -> anyhow::Result<Self> {
968 Ok(Self::new(rsa::RsaPrivateKey::new(
969 &mut rsa::rand_core::OsRng,
970 bit_size,
971 )?))
972 }
973
974 pub fn from_pkcs8_pem(pem: &str) -> anyhow::Result<Self> {
975 Ok(Self(
976 rsa::pkcs1v15::SigningKey::<sha2::Sha256>::from_pkcs8_pem(pem)?,
977 ))
978 }
979
980 pub fn to_pkcs8_pem(&self) -> anyhow::Result<zeroize::Zeroizing<String>> {
981 Ok(self.0.to_pkcs8_pem(Default::default())?)
982 }
983
984 pub fn as_rsa_priv(&self) -> &rsa::RsaPrivateKey {
985 AsRef::<rsa::RsaPrivateKey>::as_ref(&self.0)
986 }
987
988 pub fn as_rsa_pub(&self) -> &rsa::RsaPublicKey {
989 AsRef::<rsa::RsaPublicKey>::as_ref(self.as_rsa_priv())
990 }
991}
992
993pub mod expecting {
995 use super::*;
996
997 pub fn exactly<T>(
999 what: &T,
1000 ) -> impl (FnOnce(&'static str, Option<T::Owned>) -> Result<(), Error>) + use<'_, T>
1001 where
1002 T: std::fmt::Debug + PartialEq + ToOwned + ?Sized,
1003 {
1004 move |claim_name: &'static str, val_maybe: Option<T::Owned>| {
1005 if let Some(val) = val_maybe {
1006 if *what == *val.borrow() {
1007 return Ok(());
1008 }
1009 return Err(Error::InvalidClaim {
1010 claim_name,
1011 source: anyhow::anyhow!("expected {:?}; got {:?}", what, val.borrow()),
1012 });
1013 }
1014 Err(Error::MissingClaim(claim_name))
1015 }
1016 }
1017}
1018
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens an HS256 token and exercises the claim checks on its contents.
    #[test]
    fn test_jwt() {
        // Token and key appear to be the HS256 example of RFC 7515, Appendix A.1;
        // its `exp` (2011-03-22T18:43:00Z) lies in the past.
        let jwt: JWT = serde_json::from_str("\"eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk\"").unwrap();

        let key = HS256(
            base64ct::Base64UrlUnpadded::decode_vec("AyM1SysPpbyDfgZld3umj1qzKObwVMkoqQ-EstJQLr_T-1qS0gZH75aKtMN3Yj0iPS4hcgUuTwjAzZr1Z9CAow").unwrap(),
        );

        let claims = jwt.open(&key).unwrap();

        // The default timestamp checks reject the expired token.
        assert!(
            claims
                .clone()
                .into_custom::<serde_json::Value>()
                .unwrap_err()
                .to_string()
                .starts_with("expired at 2011-03-22T18:43:00Z (")
        );

        // With `exp` ignored, the unhandled `iss` claim is rejected next.
        assert_eq!(
            &claims
                .clone()
                .ignore("exp")
                .into_custom::<serde_json::Value>()
                .unwrap_err()
                .to_string(),
            "jwt contains unexpected/unhandled claim `iss`"
        );

        // The remaining custom claim of the token.
        #[derive(Deserialize, PartialEq, Eq, Debug)]
        #[serde(deny_unknown_fields)]
        struct Custom {
            #[serde(rename = "http://example.com/is_root")]
            is_root: bool,
        }

        // Once `exp` and `iss` are consumed, the rest deserializes into `Custom`.
        assert_eq!(
            claims
                .clone()
                .ignore("exp")
                .check_iss(
                    |_claim_name: &'static str, iss: Option<String>| -> Result<(), Error> {
                        assert_eq!(iss, Some("joe".to_string()));
                        Ok(())
                    }
                )
                .unwrap()
                .into_custom::<Custom>()
                .unwrap(),
            Custom { is_root: true }
        );
    }

    /// Header parsing: required `alg`, validated optional `typ`, borrowing, and
    /// rejection of unknown fields.
    #[test]
    fn test_header() {
        // `alg` is mandatory.
        assert_eq!(
            serde_json::from_str::<Header>(r#"{}"#)
                .unwrap_err()
                .to_string(),
            "missing field `alg` at line 1 column 2".to_string()
        );

        // `typ`, when present, must be "JWT"...
        assert_eq!(
            serde_json::from_str::<Header>(r#"{"typ": "not JWT", "alg": ""}"#)
                .unwrap_err()
                .to_string(),
            "invalid value: string \"not JWT\", expected the string \"JWT\" as \"typ\" at line 1 column 17".to_string()
        );

        // ...and a string.
        assert_eq!(
            serde_json::from_str::<Header>(r#"{"typ": 12,"alg":""}"#)
                .unwrap_err()
                .to_string(),
            "invalid type: integer `12`, expected the string \"JWT\" as \"typ\" at line 1 column 10".to_string()
        );

        // The `typ` comparison ignores ASCII case.
        assert!(serde_json::from_str::<Header>(r#"{"typ": "jWT","alg":""}"#).is_ok());

        // `alg` borrows from the input unless an escape sequence forces a copy.
        let header_a: Header = serde_json::from_str(r#"{"alg":"borrowed"}"#).unwrap();
        let header_b: Header = serde_json::from_str(r#"{"alg":"owned\u0020"}"#).unwrap();

        assert!(matches!(header_a.alg, Cow::Borrowed(_)));
        assert!(matches!(header_b.alg, Cow::Owned(_)));

        // Unknown header fields are rejected.
        assert_eq!(
            serde_json::from_str::<Header>(r#"{"alg":"", "unknown_field": ""}"#)
                .unwrap_err()
                .to_string(),
            "unknown field `unknown_field`, expected `typ` or `alg` at line 1 column 26"
                .to_string()
        );
    }

    /// NumericDate accepts non-negative integers and floats (truncating) and
    /// rejects negative values.
    #[test]
    fn test_numericdate() {
        assert!(NumericDate::deserialize(serde_json::json!(0u64)).is_ok());
        assert!(NumericDate::deserialize(serde_json::json!(0f64)).is_ok());
        assert!(NumericDate::deserialize(serde_json::json!(0f32)).is_ok());
        assert!(NumericDate::deserialize(serde_json::json!(i64::MIN)).is_err());
        assert!(NumericDate::deserialize(serde_json::json!(f32::MIN)).is_err());
        assert!(NumericDate::deserialize(serde_json::json!(f64::MIN)).is_err());
        assert_eq!(
            NumericDate::deserialize(serde_json::json!(1.9))
                .unwrap()
                .timestamp,
            1
        );
    }

    /// RS256 signing reproduces a known signature, and the token verifies with
    /// the matching public key.
    #[test]
    fn test_rs256() {
        // Fixed RSA key given by its components: modulus, public exponent (AQAB),
        // private exponent, and the two primes.
        let sk = RS256Sk::new(
            rsa::RsaPrivateKey::from_components(
                rsa::BigUint::from_bytes_be(
                    &base64ct::Base64UrlUnpadded::decode_vec(concat!(
                        "ofgWCuLjybRlzo0tZWJjNiuSfb4p4fAkd_wWJcyQoTbji9k0l8W26mPddx",
                        "HmfHQp-Vaw-4qPCJrcS2mJPMEzP1Pt0Bm4d4QlL-yRT-SFd2lZS-pCgNMs",
                        "D1W_YpRPEwOWvG6b32690r2jZ47soMZo9wGzjb_7OMg0LOL-bSf63kpaSH",
                        "SXndS5z5rexMdbBYUsLA9e-KXBdQOS-UTo7WTBEMa2R2CapHg665xsmtdV",
                        "MTBQY4uDZlxvb3qCo5ZwKh9kG4LT6_I5IhlJH7aGhyxXFvUK-DWNmoudF8",
                        "NAco9_h9iaGNj8q2ethFkMLs91kzk2PAcDTW9gb54h4FRWyuXpoQ",
                    ))
                    .unwrap(),
                ),
                rsa::BigUint::from_bytes_be(
                    &base64ct::Base64UrlUnpadded::decode_vec("AQAB").unwrap(),
                ),
                rsa::BigUint::from_bytes_be(
                    &base64ct::Base64UrlUnpadded::decode_vec(concat!(
                        "Eq5xpGnNCivDflJsRQBXHx1hdR1k6Ulwe2JZD50LpXyWPEAeP88vLNO97I",
                        "jlA7_GQ5sLKMgvfTeXZx9SE-7YwVol2NXOoAJe46sui395IW_GO-pWJ1O0",
                        "BkTGoVEn2bKVRUCgu-GjBVaYLU6f3l9kJfFNS3E0QbVdxzubSu3Mkqzjkn",
                        "439X0M_V51gfpRLI9JYanrC4D4qAdGcopV_0ZHHzQlBjudU2QvXt4ehNYT",
                        "CBr6XCLQUShb1juUO1ZdiYoFaFQT5Tw8bGUl_x_jTj3ccPDVZFD9pIuhLh",
                        "BOneufuBiB4cS98l2SR_RQyGWSeWjnczT0QU91p1DhOVRuOopznQ",
                    ))
                    .unwrap(),
                ),
                vec![
                    rsa::BigUint::from_bytes_be(
                        &base64ct::Base64UrlUnpadded::decode_vec(concat!(
                            "4BzEEOtIpmVdVEZNCqS7baC4crd0pqnRH_5IB3jw3bcxGn6QLvnEtfdUdi",
                            "YrqBdss1l58BQ3KhooKeQTa9AB0Hw_Py5PJdTJNPY8cQn7ouZ2KKDcmnPG",
                            "BY5t7yLc1QlQ5xHdwW1VhvKn-nXqhJTBgIPgtldC-KDV5z-y2XDwGUc",
                        ))
                        .unwrap(),
                    ),
                    rsa::BigUint::from_bytes_be(
                        &base64ct::Base64UrlUnpadded::decode_vec(concat!(
                            "uQPEfgmVtjL0Uyyx88GZFF1fOunH3-7cepKmtH4pxhtCoHqpWmT8YAmZxa",
                            "ewHgHAjLYsp1ZSe7zFYHj7C6ul7TjeLQeZD_YwD66t62wDmpe_HlB-TnBA",
                            "-njbglfIsRLtXlnDzQkv5dTltRJ11BKBBypeeF6689rjcJIDEz9RWdc",
                        ))
                        .unwrap(),
                    ),
                ],
            )
            .unwrap(),
        );

        // `header.claims` with header {"alg":"RS256"}.
        let to_sign: &str = concat!(
            "eyJhbGciOiJSUzI1NiJ9",
            ".",
            "eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFt",
            "cGxlLmNvbS9pc19yb290Ijp0cnVlfQ",
        );

        // PKCS#1 v1.5 signatures are deterministic, so the exact bytes can be compared.
        let signature = sk.sign(&to_sign.as_bytes()).unwrap();

        assert_eq!(
            signature.as_ref(),
            &Base64UrlUnpadded::decode_vec(concat!(
                "cC4hiUPoj9Eetdgtv3hF80EGrhuB__dzERat0XF9g2VtQgr9PJbu3XOiZj5RZmh7",
                "AAuHIm4Bh-0Qc_lF5YKt_O8W2Fp5jujGbds9uJdbF9CUAr7t1dnZcAcQjbKBYNX4",
                "BAynRFdiuB--f_nZLgrnbyTyWzO75vRK5h6xBArLIARNPvkSjtQBMHlb1L07Qe7K",
                "0GarZRmB_eSN9383LcOLn6_dO--xi12jzDwusC-eOkHWEsqtFZESc6BfI7noOPqv",
                "hJ1phCnvWh6IeYI2w9QOYEUipUTI8np6LbgGY9Fs98rqVt5AXLIhWkWywlVmtVrB",
                "p0igcN_IoypGlUPQGe77Rw",
            ))
            .unwrap()
        );

        // The same token with that signature appended.
        let jwt: JWT = concat!(
            "eyJhbGciOiJSUzI1NiJ9",
            ".",
            "eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFt",
            "cGxlLmNvbS9pc19yb290Ijp0cnVlfQ",
            ".",
            "cC4hiUPoj9Eetdgtv3hF80EGrhuB__dzERat0XF9g2VtQgr9PJbu3XOiZj5RZmh7",
            "AAuHIm4Bh-0Qc_lF5YKt_O8W2Fp5jujGbds9uJdbF9CUAr7t1dnZcAcQjbKBYNX4",
            "BAynRFdiuB--f_nZLgrnbyTyWzO75vRK5h6xBArLIARNPvkSjtQBMHlb1L07Qe7K",
            "0GarZRmB_eSN9383LcOLn6_dO--xi12jzDwusC-eOkHWEsqtFZESc6BfI7noOPqv",
            "hJ1phCnvWh6IeYI2w9QOYEUipUTI8np6LbgGY9Fs98rqVt5AXLIhWkWywlVmtVrB",
            "p0igcN_IoypGlUPQGe77Rw",
        )
        .to_string()
        .into();

        // Verifies against the public half of the same key.
        let pk = RS256Vk::new(sk.as_rsa_pub().clone());

        let _ = jwt.open(&pk).unwrap();
    }
}