Skip to main content

pubhubs/misc/
serde_ext.rs

1//! Tools for (de)serialization
2use serde::{
3    Deserialize, Deserializer, Serialize, Serializer, de::IntoDeserializer as _, ser::Error as _,
4};
5
6use core::fmt;
7use std::marker::PhantomData;
8
9/// Deserializes nothing, useful for ignoring deprecated fields in types annotated with
10/// `#[serde(deny_unknown_fields)]`.
11#[derive(serde::Serialize, Debug, Clone, Copy, PartialEq, Eq)]
12pub struct Skip;
13
14impl<'de> Deserialize<'de> for Skip {
15    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
16    where
17        D: serde::de::Deserializer<'de>,
18    {
19        Ok(Self {})
20    }
21}
22
23/// Deserializes an empty (json) object to a type `T`.  Panics if this is not possible.
24pub fn default_object<T: serde::de::DeserializeOwned>() -> T {
25    serde_json::from_value(serde_json::Value::Object(Default::default())).unwrap()
26}
27
28pub mod bytes_wrapper {
29    use super::*;
30
31    /// Wraps a type `T` that uses the byte array serde data type for serialization so
32    /// that the serde string data type is used instead,
33    /// according to [BytesEncoding] `O::Encoding`.
34    ///
35    /// Due to the generic (de)serialize implementation, the standard types
36    /// `Vec<u8>`, `&[u8]`, `[u8,N]`, ... use the sequence serde data type instead of byte array.
37    ///
38    /// Use [serde_bytes::Bytes], [serde_bytes::ByteBuf], and [ByteArray] instead.
39    /// (Or use [`ChangeVisitorType`] to change the visitor type to [`VisitorType::ByteSequence`].)
40    ///
41    /// We primarily use this to encode keys as hex or base64 strings in JSON instead of arrays.
42    pub struct BytesWrapper<T, O> {
43        inner: T,
44        phantom: PhantomData<O>,
45    }
46
47    // Implement some traits that are not derived correctly due to the presence of `O`.
48    impl<T: Copy, O> Copy for BytesWrapper<T, O> {}
49
50    impl<T: Clone, O> Clone for BytesWrapper<T, O> {
51        fn clone(&self) -> Self {
52            Self {
53                inner: self.inner.clone(),
54                phantom: PhantomData,
55            }
56        }
57    }
58
59    impl<T: fmt::Debug, O> fmt::Debug for BytesWrapper<T, O> {
60        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
61            self.inner.fmt(f)
62        }
63    }
64
65    impl<T: PartialEq, O> PartialEq for BytesWrapper<T, O> {
66        fn eq(&self, other: &Self) -> bool {
67            self.inner.eq(&other.inner)
68        }
69    }
70
71    impl<T: Eq, O> Eq for BytesWrapper<T, O> {}
72
73    impl<T: std::hash::Hash, O> std::hash::Hash for BytesWrapper<T, O> {
74        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
75            self.inner.hash(state)
76        }
77
78        // NOTE: we can't implement the provided mehtod `hash_slice` by forwarding
79        // it to T::hash_slice, because we cannot create a `&[T]` from a `&[Self]`
80        // without copying.
81    }
82
83    impl<T: PartialOrd, O> PartialOrd for BytesWrapper<T, O> {
84        fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
85            self.inner.partial_cmp(&other.inner)
86        }
87
88        fn lt(&self, other: &Self) -> bool {
89            self.inner.lt(&other.inner)
90        }
91
92        fn le(&self, other: &Self) -> bool {
93            self.inner.le(&other.inner)
94        }
95
96        fn gt(&self, other: &Self) -> bool {
97            self.inner.gt(&other.inner)
98        }
99
100        fn ge(&self, other: &Self) -> bool {
101            self.inner.ge(&other.inner)
102        }
103    }
104
105    impl<T: Ord, O> Ord for BytesWrapper<T, O> {
106        fn cmp(&self, other: &Self) -> std::cmp::Ordering {
107            self.inner.cmp(&other.inner)
108        }
109
110        // NOTE: we can't more efficiently implement `max`, `min` and `clamp` by forwarding to the
111        // implementations on `T`.
112    }
113
114    impl<T: Default, O> Default for BytesWrapper<T, O> {
115        fn default() -> Self {
116            Self {
117                inner: T::default(),
118                phantom: PhantomData,
119            }
120        }
121    }
122
123    /// Determines how exactly [`BytesWrapper`] should wrap the underlying type.
124    pub trait Options {
125        /// How this type is encoded (e.g. base16, base64, etc.)
126        type Encoding;
127
128        /// During deserialization, how should the byte array be visited?
129        const VISITOR_TYPE: VisitorType = VisitorType::default();
130    }
131
132    impl<E> Options for (E,)
133    where
134        E: BytesEncoding,
135    {
136        type Encoding = E;
137    }
138
139    /// Changes the [`VisitorType`] of the given [`Options`] to `VT`.
140    pub struct ChangeVisitorType<O, const VT: isize> {
141        phantom_o: PhantomData<O>,
142    }
143
144    impl<O, const VT: isize> Options for ChangeVisitorType<O, VT>
145    where
146        O: Options,
147    {
148        type Encoding = O::Encoding;
149
150        const VISITOR_TYPE: VisitorType = match VT {
151            VT_OWNED_BYTE_ARRAY => VisitorType::OwnedByteArray,
152            VT_BORROWED_BYTE_ARRAY => VisitorType::BorrowedByteArray,
153            VT_TRANSIENT_BYTE_ARRAY => VisitorType::TransientByteArray,
154            VT_BYTE_SEQUENCE => VisitorType::ByteSequence,
155            _ => panic!("Unknown visitor type"),
156        };
157    }
158
159    impl<T, O> From<T> for BytesWrapper<T, O> {
160        fn from(inner: T) -> Self {
161            Self {
162                inner,
163                phantom: PhantomData,
164            }
165        }
166    }
167
168    impl<T, O> BytesWrapper<T, O> {
169        /// Returns the wrapped object.
170        ///
171        /// Note:  We cannot implement `Into<T>` for [BytesWrapper], because it would clash
172        /// with the implementation of `Into<T>` when `T` implements `From<BytesWrapper>`.
173        pub fn into_inner(self) -> T {
174            self.inner
175        }
176
177        pub fn new(inner: T) -> Self {
178            inner.into()
179        }
180    }
181
182    const VT_OWNED_BYTE_ARRAY: isize = 0;
183    const VT_BORROWED_BYTE_ARRAY: isize = 1;
184    const VT_TRANSIENT_BYTE_ARRAY: isize = 2;
185    const VT_BYTE_SEQUENCE: isize = 3;
186
187    /// Enumerates the ways in which a sequence of bytes may be visited during deserialization.
188    #[repr(isize)]
189    pub enum VisitorType {
190        /// [serde::de::Visitor::visit_byte_buf], default
191        OwnedByteArray = VT_OWNED_BYTE_ARRAY,
192        /// [serde::de::Visitor::visit_borrowed_bytes]
193        BorrowedByteArray = VT_BORROWED_BYTE_ARRAY,
194        /// [serde::de::Visitor::visit_bytes]
195        TransientByteArray = VT_TRANSIENT_BYTE_ARRAY,
196        /// [serde::de::Visitor::visit_seq]
197        ByteSequence = VT_BYTE_SEQUENCE,
198    }
199
200    impl VisitorType {
201        const fn default() -> Self {
202            VisitorType::OwnedByteArray
203        }
204    }
205
206    /// Trait for specifying the encoding of bytes as strings, like hex or base64.
207    pub trait BytesEncoding {
208        type Error: std::error::Error;
209
210        /// Encodes `src` into `dst`, returning the slice of `dst` that was written.
211        ///
212        /// The caller must ensure that `len(dst) >= encoded_len(src).unwrap()`.
213        fn encode<'a>(src: &[u8], dst: &'a mut str) -> Result<&'a str, Self::Error>;
214
215        /// Decodes `src` into `dst`, returning the slice of `dst` that was written.
216        ///
217        /// The caller must ensure that `len(dst) >= decoded_len(src).unwrap()`.
218        fn decode<'a>(src: &str, dst: &'a mut [u8]) -> Result<&'a [u8], Self::Error>;
219
220        /// See [Self::encode].
221        fn encoded_len(bytes: &[u8]) -> Result<usize, Self::Error>;
222
223        /// See [Self::decode].
224        fn decoded_len(bytes: &str) -> Result<usize, Self::Error>;
225    }
226
227    /// Hex [BytesEncoding].
228    pub struct B16Encoding<
229        const ENCODE_LOWER_CASE: bool = { true },
230        const DECODE_MIXED_CASE: bool = { true },
231    > {}
232
233    /// Wrapper around `T` implementing (de)serialization using hex-encoding.
234    pub type B16<
235        T = serde_bytes::ByteBuf,
236        const ENCODE_LOWER_CASE: bool = true,
237        const DECODE_MIXED_CASE: bool = true,
238    > = BytesWrapper<T, (B16Encoding<ENCODE_LOWER_CASE, DECODE_MIXED_CASE>,)>;
239
240    impl<const ELC: bool, const DMC: bool> BytesEncoding for B16Encoding<ELC, DMC> {
241        type Error = base16ct::Error;
242
243        fn encode<'a>(src: &[u8], dst: &'a mut str) -> Result<&'a str, Self::Error> {
244            let dst: &'a mut [u8] = unsafe { dst.as_bytes_mut() };
245            // SAFETY: hex characters are valid utf8
246
247            if ELC {
248                base16ct::lower::encode_str(src, dst)
249            } else {
250                base16ct::upper::encode_str(src, dst)
251            }
252        }
253
254        fn decode<'a>(src: &str, dst: &'a mut [u8]) -> Result<&'a [u8], Self::Error> {
255            let src: &[u8] = src.as_bytes();
256
257            if DMC {
258                base16ct::mixed::decode(src, dst)
259            } else if ELC {
260                base16ct::lower::decode(src, dst)
261            } else {
262                base16ct::upper::decode(src, dst)
263            }
264        }
265
266        fn encoded_len(bytes: &[u8]) -> Result<usize, Self::Error> {
267            if bytes.len() >= usize::MAX / 2 {
268                Err(base16ct::Error::InvalidLength)
269            } else {
270                Ok(base16ct::encoded_len(bytes))
271            }
272        }
273
274        fn decoded_len(bytes: &str) -> Result<usize, Self::Error> {
275            base16ct::decoded_len(bytes.as_bytes())
276        }
277    }
278
279    /// Base64 [BytesEncoding]
280    pub struct B64Encoding<Enc: base64ct::Encoding> {
281        phantom: PhantomData<Enc>,
282    }
283
284    /// Wrapper around `T` implementing (de)serialization using [base64ct::Base64].
285    pub type B64<T = serde_bytes::ByteBuf> = BytesWrapper<T, (B64Encoding<base64ct::Base64>,)>;
286
287    /// Wrapper around `T` implementing (de)serialization using [base64ct::Base64UrlUnpadded].
288    pub type B64UU<T = serde_bytes::ByteBuf> =
289        BytesWrapper<T, (B64Encoding<base64ct::Base64UrlUnpadded>,)>;
290
291    impl<Enc: base64ct::Encoding> BytesEncoding for B64Encoding<Enc> {
292        type Error = base64ct::Error;
293
294        fn encode<'a>(src: &[u8], dst: &'a mut str) -> Result<&'a str, Self::Error> {
295            // SAFETY: all the base64ct alphabets are valid utf-8
296            Enc::encode(src, unsafe { dst.as_bytes_mut() }).map_err(Into::into)
297        }
298
299        fn decode<'a>(src: &str, dst: &'a mut [u8]) -> Result<&'a [u8], Self::Error> {
300            Enc::decode(src, dst)
301        }
302
303        fn encoded_len(bytes: &[u8]) -> Result<usize, Self::Error> {
304            if bytes.len() >= usize::MAX / 4 {
305                Err(base64ct::Error::InvalidLength)
306            } else {
307                Ok(Enc::encoded_len(bytes))
308            }
309        }
310
311        fn decoded_len(bytes: &str) -> Result<usize, Self::Error> {
312            // NOTE: base64ct provides no `decoded_len` function, so we overestimate
313            // the decoded length as the original length
314            Ok(bytes.len())
315        }
316    }
317
318    impl<T, O> std::ops::Deref for BytesWrapper<T, O> {
319        type Target = T;
320
321        fn deref(&self) -> &Self::Target {
322            &self.inner
323        }
324    }
325
326    impl<T, O> std::ops::DerefMut for BytesWrapper<T, O> {
327        fn deref_mut(&mut self) -> &mut Self::Target {
328            &mut self.inner
329        }
330    }
331
332    /// Contains implementation details for [`EncodingSerializer`].
333    mod encoding_serializer {
334        use super::*;
335
        // Expands to an `Err` built with `Self::Error::custom` from an `ExpectedBytesError`,
        // stating that bytes were expected but `$got` was encountered.  It is used inside the
        // methods of the `Serializer` impl below, where `Self::Error` is the serializer's error.
        //
        // `$got` is stringified as written, so a string-literal argument such as `"none"`
        // keeps its quotes in the message.
        macro_rules! expected_bytes {
            ($got : tt) => {
                Err(Self::Error::custom(ExpectedBytesError {
                    got: stringify!($got),
                }))
            };
        }

        // Generates one serializer method `$f(self, _v: $t)` per `$f: $t,` pair, each of
        // which rejects its input with `expected_bytes!($t)`.
        macro_rules! serialize_primitives {
            ($($f: ident: $t:ty,)*) => {
                $(
                    fn $f(self, _v:$t) -> Result<Self::Ok, Self::Error> {
                        expected_bytes!($t)
                    }
                )*
            }
        }
353
354        /// Serializes a byte array by encoding it according to [BytesEncoding] `E` and passing
355        /// the resulting string to the [Serializer] `S`.
356        pub(super) struct EncodingSerializer<S, E> {
357            s: S,
358            phantom: PhantomData<E>,
359        }
360
361        impl<S, E> EncodingSerializer<S, E>
362        where
363            S: Serializer,
364            E: BytesEncoding,
365        {
366            pub(super) fn new(s: S) -> Self {
367                Self {
368                    s,
369                    phantom: PhantomData,
370                }
371            }
372        }
373
374        impl<S, E> Serializer for EncodingSerializer<S, E>
375        where
376            S: Serializer,
377            E: BytesEncoding,
378        {
379            type Ok = S::Ok;
380            type Error = S::Error;
381            type SerializeSeq = serde::ser::Impossible<S::Ok, Self::Error>;
382            type SerializeTuple = encoding_serializer::SerializeTuple<S, E>;
383            type SerializeTupleStruct = serde::ser::Impossible<S::Ok, Self::Error>;
384            type SerializeTupleVariant = serde::ser::Impossible<S::Ok, Self::Error>;
385            type SerializeMap = serde::ser::Impossible<S::Ok, Self::Error>;
386            type SerializeStruct = serde::ser::Impossible<S::Ok, Self::Error>;
387            type SerializeStructVariant = serde::ser::Impossible<S::Ok, Self::Error>;
388
389            fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
390                let encoded_len: usize = match E::encoded_len(v) {
391                    Ok(encoded_len) => encoded_len,
392                    Err(err) => return Err(S::Error::custom(err)),
393                };
394
395                let mut string = unsafe { String::from_utf8_unchecked(vec![0; encoded_len]) };
396                // SAFETY: only zeroes is valid utf8
397
398                let substr: &str = match E::encode(v, &mut string) {
399                    Ok(substr) => substr,
400                    Err(err) => return Err(S::Error::custom(err)),
401                };
402
403                self.s.serialize_str(substr)
404            }
405
406            fn serialize_tuple(self, len: usize) -> Result<Self::SerializeTuple, Self::Error> {
407                Ok(encoding_serializer::SerializeTuple::<S, E>::new(self, len))
408            }
409
410            serialize_primitives! {
411                serialize_bool: bool,
412                serialize_i8: i8,
413                serialize_i16: i16,
414                serialize_i32: i32,
415                serialize_i64: i64,
416                serialize_i128: i128,
417                serialize_u8: u8,
418                serialize_u16: u16,
419                serialize_u32: u32,
420                serialize_u64: u64,
421                serialize_u128: u128,
422                serialize_f32: f32,
423                serialize_f64: f64,
424                serialize_char: char,
425                serialize_str: &str,
426                serialize_unit_struct: &'static str,
427            }
428
429            fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
430                expected_bytes!("none")
431            }
432
433            fn serialize_some<T>(self, _value: &T) -> Result<Self::Ok, Self::Error>
434            where
435                T: Serialize + ?Sized,
436            {
437                expected_bytes!("some")
438            }
439
440            fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
441                expected_bytes!("unit")
442            }
443
444            fn serialize_unit_variant(
445                self,
446                _name: &'static str,
447                _variant_index: u32,
448                _variant: &'static str,
449            ) -> Result<Self::Ok, Self::Error> {
450                expected_bytes!("unit variant")
451            }
452
453            fn serialize_newtype_struct<T>(
454                self,
455                _name: &'static str,
456                _value: &T,
457            ) -> Result<Self::Ok, Self::Error>
458            where
459                T: Serialize + ?Sized,
460            {
461                expected_bytes!("newtype struct")
462            }
463
464            fn serialize_newtype_variant<T>(
465                self,
466                _name: &'static str,
467                _variant_index: u32,
468                _variant: &'static str,
469                _value: &T,
470            ) -> Result<Self::Ok, Self::Error>
471            where
472                T: Serialize + ?Sized,
473            {
474                expected_bytes!("newtype variant")
475            }
476
477            fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
478                expected_bytes!("seq")
479            }
480
481            fn serialize_tuple_struct(
482                self,
483                _name: &'static str,
484                _len: usize,
485            ) -> Result<Self::SerializeTupleStruct, Self::Error> {
486                expected_bytes!("tuple struct")
487            }
488
489            fn serialize_tuple_variant(
490                self,
491                _name: &'static str,
492                _variant_index: u32,
493                _variant: &'static str,
494                _len: usize,
495            ) -> Result<Self::SerializeTupleVariant, Self::Error> {
496                expected_bytes!("tuple variant")
497            }
498
499            fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
500                expected_bytes!("map")
501            }
502
503            fn serialize_struct(
504                self,
505                _name: &'static str,
506                _len: usize,
507            ) -> Result<Self::SerializeStruct, Self::Error> {
508                expected_bytes!("struct")
509            }
510
511            fn serialize_struct_variant(
512                self,
513                _name: &'static str,
514                _variant_index: u32,
515                _variant: &'static str,
516                _len: usize,
517            ) -> Result<Self::SerializeStructVariant, Self::Error> {
518                expected_bytes!("struct variant")
519            }
520        }
521
522        #[derive(thiserror::Error, Debug)]
523        #[error(
524            "to use a bytes encoding (like base64) for the serialization of a type, that type must serialize to bytes, but got {got}"
525        )]
526        struct ExpectedBytesError {
527            got: &'static str,
528        }
529
530        pub(super) struct SerializeTuple<S, E> {
531            inner: Vec<u8>,
532            encoding_serializer: EncodingSerializer<S, E>,
533            expected_len: usize,
534        }
535
536        impl<S, E> SerializeTuple<S, E> {
537            fn new(encoding_serializer: EncodingSerializer<S, E>, expected_len: usize) -> Self
538            where
539                S: Serializer,
540                E: BytesEncoding,
541            {
542                Self {
543                    encoding_serializer,
544                    expected_len,
545                    inner: Vec::<u8>::with_capacity(expected_len),
546                }
547            }
548        }
549
550        impl<S: Serializer, E: BytesEncoding> serde::ser::SerializeTuple for SerializeTuple<S, E> {
551            type Ok = S::Ok;
552            type Error = S::Error;
553
554            fn serialize_element<T: Serialize + ?Sized>(
555                &mut self,
556                value: &T,
557            ) -> Result<(), Self::Error> {
558                if self.inner.len() == self.expected_len {
559                    return Err(Self::Error::custom(
560                        "improper use of serializer: serializing more tuple elements than announced",
561                    ));
562                }
563
564                // TODO, maybe: proper `ByteSerializer` implementation to replace this hack
565                let byte =
566                    u8::deserialize(value.serialize(serde_json::value::Serializer).map_err(
567                        |err| {
568                            Self::Error::custom(format!(
569                                "failed to serialize to byte (hackingly via serde_json): {err}"
570                            ))
571                        },
572                    )?)
573                    .map_err(|err| {
574                        Self::Error::custom(format!(
575                            "failed to serialize to byte (hackingly via serde_json): {err}"
576                        ))
577                    })?;
578                self.inner.push(byte);
579
580                Ok(())
581            }
582
583            fn end(self) -> Result<Self::Ok, Self::Error> {
584                if self.inner.len() != self.expected_len {
585                    return Err(Self::Error::custom(
586                        "improper use of serializer: serialized less tuple elements than announced",
587                    ));
588                }
589
590                self.encoding_serializer.serialize_bytes(&self.inner)
591            }
592        }
593    }
594
595    use encoding_serializer::EncodingSerializer;
596
597    impl<T, O: Options> Serialize for BytesWrapper<T, O>
598    where
599        T: Serialize,
600        O::Encoding: BytesEncoding,
601    {
602        fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
603            self.inner
604                .serialize(EncodingSerializer::<_, O::Encoding>::new(s))
605        }
606    }
607
608    impl<T, O: Options> core::str::FromStr for BytesWrapper<T, O>
609    where
610        T: for<'de> Deserialize<'de>,
611        O::Encoding: BytesEncoding,
612    {
613        type Err = serde::de::value::Error;
614
615        fn from_str(s: &str) -> Result<Self, Self::Err> {
616            Self::deserialize(s.into_deserializer())
617        }
618    }
619
620    impl<T, O: Options> std::fmt::Display for BytesWrapper<T, O>
621    where
622        T: Serialize,
623        O::Encoding: BytesEncoding,
624    {
625        fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
626            self.serialize(f)
627        }
628    }
629
630    /// Extracts a `O::Value` from the given [Deserializer] by extracting a string,
631    /// decoding this to a byte array according to [BytesEncoding] `O::Encoding`,
632    /// and finally passing this byte array to the [Deserialize] implementation of `T`.
633    struct EncodedBytesVisitor<T, O> {
634        phantom_t: PhantomData<T>,
635        phantom_o: PhantomData<O>,
636    }
637
638    impl<T, O: Options> EncodedBytesVisitor<T, O>
639    where
640        T: serde::de::DeserializeOwned,
641        O::Encoding: BytesEncoding,
642    {
643        fn new() -> Self {
644            Self {
645                phantom_t: PhantomData,
646                phantom_o: PhantomData,
647            }
648        }
649    }
650
651    impl<T, O: Options> serde::de::Visitor<'_> for EncodedBytesVisitor<T, O>
652    where
653        T: serde::de::DeserializeOwned,
654        O::Encoding: BytesEncoding,
655    {
656        type Value = T;
657
658        fn expecting(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
659            write!(f, "str")
660        }
661
662        fn visit_str<Error: serde::de::Error>(self, v: &str) -> Result<Self::Value, Error> {
663            let decoded_len: usize = match O::Encoding::decoded_len(v) {
664                Ok(decoded_len) => decoded_len,
665                Err(err) => return Err(Error::custom(err)), // TODO: better err?
666            };
667
668            let mut buf = vec![0; decoded_len];
669
670            let slice: &[u8] = match O::Encoding::decode(v, &mut buf) {
671                Ok(slice) => slice,
672                Err(err) => return Err(Error::custom(err)), // TODO: better err?
673            };
674
675            let slice_len = slice.len();
676            let slice_ptr = slice.as_ptr();
677
678            // truncate buf to the size used by decode, but first check that slice
679            // is indeed a slice into buf starting at index 0
680            assert_eq!(buf.as_ptr(), slice_ptr);
681            buf.truncate(slice_len);
682
683            match O::VISITOR_TYPE {
684                VisitorType::OwnedByteArray => T::deserialize(ByteBufDeserializer::new(buf)),
685                VisitorType::BorrowedByteArray => {
686                    T::deserialize(serde::de::value::BorrowedBytesDeserializer::new(&buf))
687                }
688                VisitorType::TransientByteArray => {
689                    T::deserialize(serde::de::value::BytesDeserializer::new(&buf))
690                }
691                VisitorType::ByteSequence => {
692                    T::deserialize(serde::de::value::SeqDeserializer::new(buf.into_iter()))
693                }
694            }
695        }
696    }
697
698    /// A [Deserializer] owning a `Vec<u8>` that always calls [serde::de::Visitor::visit_byte_buf].
699    #[derive(Clone)]
700    pub struct ByteBufDeserializer<E> {
701        value: Vec<u8>,
702        marker: PhantomData<E>,
703    }
704
705    impl<E> ByteBufDeserializer<E> {
706        pub fn new(value: Vec<u8>) -> Self {
707            Self {
708                value,
709                marker: PhantomData,
710            }
711        }
712    }
713
714    impl<'de, E> serde::de::Deserializer<'de> for ByteBufDeserializer<E>
715    where
716        E: serde::de::Error,
717    {
718        type Error = E;
719
720        fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
721        where
722            V: serde::de::Visitor<'de>,
723        {
724            visitor.visit_byte_buf(self.value)
725        }
726
727        serde::forward_to_deserialize_any! {
728            bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
729            bytes byte_buf option unit unit_struct newtype_struct seq tuple
730            tuple_struct map struct identifier ignored_any enum
731        }
732    }
733
734    impl<E> core::fmt::Debug for ByteBufDeserializer<E> {
735        fn fmt(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
736            formatter
737                .debug_struct("ByteBufDeserializer")
738                .field("value", &self.value)
739                .finish()
740        }
741    }
742
743    impl<'de, T, O: Options> Deserialize<'de> for BytesWrapper<T, O>
744    where
745        T: for<'de2> Deserialize<'de2>,
746        O::Encoding: BytesEncoding,
747    {
748        fn deserialize<D>(d: D) -> Result<Self, D::Error>
749        where
750            D: Deserializer<'de>,
751        {
752            Ok(d.deserialize_str(EncodedBytesVisitor::<T, O>::new())?
753                .into())
754        }
755    }
756
757    #[cfg(test)]
758    mod tests {
759        use super::*;
760
761        #[test]
762        fn serialize_bytes_wrapper() {
763            assert_eq!(
764                &serde_json::to_string(&B64UU::<_>::from(serde_bytes::ByteBuf::from([0; 32])))
765                    .unwrap(),
766                "\"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\""
767            );
768        }
769
770        #[test]
771        fn byte_array_deserialization() {
772            assert_eq!(
773                ByteArray::<4>::deserialize(serde::de::value::BytesDeserializer::<
774                    serde::de::value::Error,
775                >::new(b"test"))
776                .unwrap()
777                .inner,
778                *b"test"
779            );
780
781            assert_eq!(
782                ByteArray::<4>::deserialize(serde::de::value::BorrowedBytesDeserializer::<
783                    serde::de::value::Error,
784                >::new(b"test"))
785                .unwrap()
786                .inner,
787                *b"test"
788            );
789
790            assert_eq!(
791                ByteArray::<4>::deserialize(ByteBufDeserializer::<serde::de::value::Error>::new(
792                    b"test".to_vec()
793                ))
794                .unwrap()
795                .inner,
796                *b"test"
797            );
798        }
799
800        #[test]
801        fn ed25519_dalek_bug() {
802            let bytes = base16ct::lower::decode_vec(
803                "66b1419fae979516fb3807dda1b05026b2570a7ab2190254e524af4f0934ddd2",
804            )
805            .unwrap();
806
807            let d =
808                serde::de::value::BorrowedBytesDeserializer::<serde::de::value::Error>::new(&bytes);
809            ed25519_dalek::VerifyingKey::deserialize(d).unwrap(); // works
810
811            let d = serde::de::value::BytesDeserializer::<serde::de::value::Error>::new(&bytes);
812            ed25519_dalek::VerifyingKey::deserialize(d).unwrap();
813            // This *used to* err, but now works after
814            // https://github.com/dalek-cryptography/curve25519-dalek/pull/602 is released.
815
816            let d = serde::de::value::SeqDeserializer::<_, serde::de::value::Error>::new(
817                [0u8; 32].into_iter(),
818            );
819            curve25519_dalek::scalar::Scalar::deserialize(d).unwrap(); // works
820
821            let d = serde::de::value::BytesDeserializer::<serde::de::value::Error>::new(&[0u8; 32]);
822            curve25519_dalek::scalar::Scalar::deserialize(d).unwrap_err();
823            // This errs, but since `Scalar`s, unlike `{Signing,Verifying}Key`s, are serialized as
824            // byte sequences instead of byte arrays, this will probably not change anytime soon.
825        }
826    }
827}
828
829pub use bytes_wrapper::BytesWrapper;
830
/// Wrapper around `[u8; N]` that (de)serializes using the serde byte array data type,
/// instead of the tuple data type used by `[u8; N]` itself.
#[derive(Copy, Debug, Clone, PartialEq, Eq, Hash)]
pub struct ByteArray<const N: usize> {
    /// The wrapped bytes.
    inner: [u8; N],
}
836
837impl<const N: usize> std::ops::Deref for ByteArray<N> {
838    type Target = [u8; N];
839
840    fn deref(&self) -> &Self::Target {
841        &self.inner
842    }
843}
844
845impl<const N: usize> From<[u8; N]> for ByteArray<N> {
846    fn from(inner: [u8; N]) -> Self {
847        Self { inner }
848    }
849}
850
851impl<const N: usize> From<ByteArray<N>> for [u8; N] {
852    fn from(val: ByteArray<N>) -> Self {
853        val.inner
854    }
855}
856
857impl<const N: usize> Serialize for ByteArray<N> {
858    fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
859        s.serialize_bytes(&self.inner)
860    }
861}
862
863/// Extracts a [ByteArray] from a [serde::Deserializer].
864struct ByteArrayVisitor<const N: usize> {}
865
866impl<const N: usize> serde::de::Visitor<'_> for ByteArrayVisitor<N> {
867    type Value = [u8; N];
868
869    fn expecting(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
870        write!(f, "a byte array of length {N}")
871    }
872
873    fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
874    where
875        E: serde::de::Error,
876    {
877        <[u8; N]>::try_from(v).map_err(|v| E::invalid_length(v.len(), &self))
878    }
879
880    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
881    where
882        E: serde::de::Error,
883    {
884        <[u8; N]>::try_from(v).map_err(|_| E::invalid_length(v.len(), &self))
885    }
886}
887
888impl<'de, const N: usize> Deserialize<'de> for ByteArray<N> {
889    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
890        Ok(d.deserialize_byte_buf(ByteArrayVisitor::<N> {})?.into())
891    }
892}