Skip to main content

pubhubs/common/
elgamal.rs

1//! The ElGamal cryptosystem, as used in PEP
2
3use curve25519_dalek::{
4    constants::RISTRETTO_BASEPOINT_TABLE as B,
5    ristretto::{CompressedRistretto, RistrettoPoint},
6    scalar::Scalar,
7};
8
9/// ElGamal ciphertext - the result of [`PublicKey::encrypt`].
10///
11/// The associated public key is remembered to allow rerandomization, but this public key is
12/// not authenticated in any way.  This means that anyone intercepting a triple may
13/// modify the public key without detection (but this does not cause the
14/// triple to be decryptable to the same plaintext by another public key.)
15#[derive(Debug, Clone, PartialEq, Eq)]
16pub struct Triple {
17    /// Ephemeral key
18    ek: RistrettoPoint,
19    /// Ciphertext,
20    ct: RistrettoPoint,
21    /// Public key
22    pk: RistrettoPoint,
23}
24
25impl Triple {
26    /// Decrypts the triple using the given private key `sk`.  If the triple was encrypted
27    /// for a different private key, the result is a random point.
28    pub fn decrypt(self, sk: &PrivateKey) -> RistrettoPoint {
29        self.ct - sk.scalar * self.ek
30    }
31
32    /// Decrypts the triple using the given private key `sk` if the triple claims to be encrypted
33    /// for the associated public key;  returns `None` otherwise.
34    ///
35    /// **Warning** This function can't check whether the triple's public key `pk` has been
36    /// tampered with.  
37    ///
38    /// While tampering cannot be prevented, the plaintext of a triple with spoofed `pk` can be
39    /// garbled, using [Self::rerandomize].
40    ///
41    pub fn decrypt_and_check_pk(self, sk: &PrivateKey) -> Option<RistrettoPoint> {
42        if self.pk == B * &sk.scalar {
43            Some(self.decrypt(sk))
44        } else {
45            None
46        }
47    }
48
49    /// Changes the public key of this triple, likely resulting in garbage down the road.
50    ///
51    /// Used for demonstration purposes.
52    pub fn spoof_pk(self, pk: PublicKey) -> Triple {
53        Triple {
54            ek: self.ek,
55            ct: self.ct,
56            pk: pk.point,
57        }
58    }
59
60    /// Changes the appearance of the ciphertext, but leaves the plaintext and the target
61    /// public key unaltered.  If the public key was spoofed, the plaintext is garbled.
62    /// ```
63    /// use pubhubs::common::elgamal::{PrivateKey, random_point, random_scalar};
64    /// use curve25519_dalek::{
65    ///     ristretto::RistrettoPoint,
66    ///     constants::RISTRETTO_BASEPOINT_TABLE as B,
67    /// };
68    ///
69    /// let M = random_point();
70    /// let sk = PrivateKey::random();
71    /// let pk = sk.public_key();
72    ///
73    /// let r1 = random_scalar();
74    /// let r2 = random_scalar();
75    ///
76    /// // Rerandomization leaves the plaintext unchanged:
77    /// let trip = pk.encrypt_with_random(r1, M).rerandomize_with_random(r2);
78    /// assert_eq!(trip, pk.encrypt_with_random(r1+r2,M));
79    ///
80    /// // But if the public key was spoofed, the plaintext is garbled:
81    /// let sk2 = PrivateKey::random();
82    /// let pk2 = sk2.public_key().clone();
83    /// let trip = pk.encrypt_with_random(r1, M).spoof_pk(pk2).rerandomize_with_random(r2);
84    ///
85    /// assert_eq!(trip.clone().decrypt_and_check_pk(&sk2),
86    ///     Some(M + B * &(r1 * (sk.as_scalar()-sk2.as_scalar()))));
87    ///
88    /// // Indeed, if sk =/= sk2, then  r1(sk - sk2)B will be some random unknowable Ristretto
89    /// // point, because r1 should be a random scalar that has been thrown away.
90    /// ```
91    pub fn rerandomize(self) -> Triple {
92        self.rerandomize_with_random(random_scalar())
93    }
94
95    /// Like [Self::rerandomize], but you can specify the random scalar used -
96    /// which you shouldn't except to make deterministic tests.
97    pub fn rerandomize_with_random(self, r: Scalar) -> Triple {
98        Triple {
99            ek: self.ek + &r * B,
100            ct: self.ct + r * self.pk,
101            pk: self.pk,
102        }
103    }
104
105    /// Like [rsk] but taking the parameters `s` and `k` thusly: `rsk_with_s(s).and_k(k)`.
106    pub fn rsk_with_s(self, s: &Scalar) -> rsk::WithS<'_> {
107        rsk::WithS { t: self, s }
108    }
109
110    /// Changes the given ciphertext according to the `params` provided:
111    ///
112    ///  - Multiplies the underlying plaintext by `params.s()`;
113    ///  - Multiplies the target public/private key by `params.k()`;
114    ///  - Rerandomizes the ciphertext using the scalar `params.r()`.
115    ///    
116    ///    If the public key `self.pk` was spoofed, the resulting plaintext is garbled,
117    ///    provided the scalar `params.r()` was random.
118    ///
119    /// If you only need to specify `s` and `k`, use `triple.rsk_with_s(s).and_k(k)` instead.
120    pub fn rsk(self, params: impl rsk::Params) -> Triple {
121        let r: Scalar = params.r();
122        let kpk = self.pk * params.k();
123
124        Triple {
125            ek: params.s_over_k() * self.ek + &r * B,
126            ct: params.s() * self.ct + r * kpk,
127            pk: kpk,
128        }
129    }
130}
131
132/// Utilities for [Triple::rsk]
133pub mod rsk {
134    use super::*;
135
136    /// Implementation of the [Params] trait given the parameters `s` and `k`.
137    pub struct SAndK<'s, 'k> {
138        s: &'s Scalar,
139        k: &'k Scalar,
140    }
141
142    impl Params for SAndK<'_, '_> {
143        fn s(&self) -> &Scalar {
144            self.s
145        }
146
147        fn k(&self) -> &Scalar {
148            self.k
149        }
150    }
151
152    /// The result of [Triple::rsk_with_s]. You should call [WithS::and_k] on it.
153    pub struct WithS<'a> {
154        pub(crate) t: Triple,
155        pub(crate) s: &'a Scalar,
156    }
157
158    impl WithS<'_> {
159        pub fn and_k(self, k: &Scalar) -> Triple {
160            self.t.rsk(SAndK { s: self.s, k })
161        }
162    }
163
164    /// Utilities for the [Triple::rsk] operation.
165    pub trait Params {
166        /// Multiply the encrypted plaintext ristretto point by this scalar.
167        fn s(&self) -> &Scalar;
168
169        /// Multiply the target public/private key by this scalar.
170        fn k(&self) -> &Scalar;
171
172        /// Returns `1/k`.
173        fn k_inv(&self) -> Scalar {
174            self.k().invert()
175        }
176
177        /// Returns `s/k`.
178        fn s_over_k(&self) -> Scalar {
179            self.s() * self.k_inv()
180        }
181
182        /// Returns the scalar used for rerandomisation.
183        ///
184        /// **Warning:** only override this method for the purpose of making deterministic test.
185        fn r(&self) -> Scalar {
186            random_scalar()
187        }
188    }
189}
190
// Expands to `&mut aead::OsRng`, the RNG handed to `random_point` and `random_scalar`
// (presumably the operating-system RNG from `rand_core`, re-exported by `aead` — confirm).
macro_rules! osrng {
    () => (&mut aead::OsRng);
}
196
197/// Returns a random Ristretto point, mainly for examples.
198///
199/// If you're immediately encrypting this point, consider
200/// using [PublicKey::encrypt_random] instead.
201pub fn random_point() -> RistrettoPoint {
202    RistrettoPoint::random(osrng!())
203}
204
205/// Returns a random scalar, mainly for examples.
206pub fn random_scalar() -> Scalar {
207    Scalar::random(osrng!())
208}
209
210/// Private key - load using [`PrivateKey::from_hex`] or generate with [`PrivateKey::random`].
211///
212/// Caches the associated [`PublicKey`], which means that loading a [`PrivateKey`] involves a base
213/// point multiplication.
214#[derive(Clone, PartialEq, Eq, Debug)]
215pub struct PrivateKey {
216    /// underlying scalar
217    scalar: Scalar,
218
219    /// associated public key, stored for efficiency
220    public_key: PublicKey,
221}
222
223impl PrivateKey {
224    /// Returns reference to underlying scalar.
225    pub fn as_scalar(&self) -> &Scalar {
226        &self.scalar
227    }
228
229    pub fn random() -> Self {
230        random_scalar().into()
231    }
232
233    pub fn public_key(&self) -> &PublicKey {
234        &self.public_key
235    }
236
237    /// Computes the [PublicKey] associated with the product of two [PrivateKey]s given only one
238    /// private key.
239    pub fn scale(&self, pk: &PublicKey) -> PublicKey {
240        (self.scalar * pk.point).into()
241    }
242
243    /// Creates a Diffie-Hellman-type shared secret between this [`PrivateKey`] and the [`PublicKey`].
244    pub fn shared_secret(&self, pk: &PublicKey) -> SharedSecret {
245        SharedSecret {
246            inner: self.scale(pk).to_bytes(),
247        }
248    }
249}
250
251impl From<Scalar> for PrivateKey {
252    fn from(scalar: Scalar) -> Self {
253        PrivateKey {
254            scalar,
255            public_key: (&scalar * B).into(),
256        }
257    }
258}
259
260/// Public key - obtained using [`PublicKey::from_hex`] or [`PrivateKey::public_key`].
261#[derive(Clone, PartialEq, Eq, Debug)]
262pub struct PublicKey {
263    point: RistrettoPoint,
264    compressed: CompressedRistretto,
265}
266
267impl AsRef<[u8]> for PublicKey {
268    /// Returns a reference to the compressed encoding of this public key
269    fn as_ref(&self) -> &[u8] {
270        self.compressed.as_bytes().as_slice()
271    }
272}
273
274impl PublicKey {
275    /// Turns a 64 digit hex string into a [`PublicKey`].
276    ///
277    /// Returns `None` when the hex-encoding is invalid or when the hex-encoding does not encode a
278    /// valid Ristretto point.
279    pub fn from_hex(hexstr: &str) -> Option<Self> {
280        CompressedRistretto::from_hex(hexstr)?.try_into().ok()
281    }
282
283    /// Encrypts the given `plaintext` for this public key.
284    /// If the plaintext is a random point, consider using [Self::encrypt_random].
285    pub fn encrypt(&self, plaintext: RistrettoPoint) -> Triple {
286        self.encrypt_with_random(random_scalar(), plaintext)
287    }
288
289    /// Like [`Self::encrypt`], but you can specify the random scalar used - which you shouldn't
290    /// except to make deterministic tests.
291    pub fn encrypt_with_random(&self, r: Scalar, plaintext: RistrettoPoint) -> Triple {
292        Triple {
293            ek: &r * B,
294            ct: plaintext + r * self.point,
295            pk: self.point,
296        }
297    }
298
299    /// Effectively encrypts a random plaintext for this public key.
300    ///
301    /// Instead of picking random Ristretto point M and random scalar r and computing
302    ///   `(rB, r * pk + M, self)`
303    /// we pick Ristretto points ek and ct randomly and return
304    ///   `(ek, ct, sekf)`.
305    /// since this is more efficient, and yields the same distribution.
306    pub fn encrypt_random(&self) -> Triple {
307        Triple {
308            ek: random_point(),
309            ct: random_point(),
310            pk: self.point,
311        }
312    }
313}
314
315impl From<RistrettoPoint> for PublicKey {
316    fn from(point: RistrettoPoint) -> Self {
317        Self {
318            point,
319            compressed: point.compress(),
320        }
321    }
322}
323
324impl TryFrom<CompressedRistretto> for PublicKey {
325    type Error = ();
326
327    fn try_from(compressed: CompressedRistretto) -> Result<Self, Self::Error> {
328        Ok(Self {
329            point: compressed.decompress().ok_or(())?,
330            compressed,
331        })
332    }
333}
334
335/// Adds encoding and decoding methods to [`PrivateKey`], [`PublicKey`], [`Triple`], [`Scalar`]
336/// and [`RistrettoPoint`] which can all be represented as `[u8; N]`s for some `N`.  
337///
338/// Not all arrays of the form `[u8; N]` may be a valid representation of the type of object in question, though.
339pub trait Encoding<const N: usize>
340where
341    Self: Sized,
342{
343    /// Decodes `Some(object)` from `bytes` if `bytes` encodes some `object` of type `Self`;
344    /// otherwise returns `None`.
345    fn from_bytes(bytes: [u8; N]) -> Option<Self>;
346
347    /// Encodes `self` as `[u8; N]`.
348    fn to_bytes(&self) -> [u8; N];
349
350    /// Like [Self::from_bytes], but reads `[u8; N]` from `slice`.  Returns `None` if `slice.len()!=N`
351    /// or when the slice is not a valid encoding.
352    fn from_slice(slice: &[u8]) -> Option<Self> {
353        if slice.len() != N {
354            return None;
355        }
356
357        let mut buf = [0u8; N];
358        buf.copy_from_slice(slice);
359
360        Self::from_bytes(buf)
361    }
362
363    /// Copies the encoding of `self` into `slice`.  Returns `None` when `slice.len()!=N`.
364    fn copy_to_slice(&self, slice: &mut [u8]) -> Option<()> {
365        if slice.len() != N {
366            return None;
367        }
368
369        slice.copy_from_slice(&self.to_bytes());
370
371        Some(())
372    }
373
374    /// Like [Self::from_bytes], but reads the `[u8; N]` from the 2*N-digit hex string `hex`.
375    /// The case of the hex digits is ignored.
376    fn from_hex(hex: &str) -> Option<Self> {
377        let hex: &[u8] = hex.as_bytes();
378
379        if hex.len() != 2 * N {
380            return None;
381        }
382
383        let mut buf = [0u8; N];
384
385        base16ct::mixed::decode(hex, &mut buf).ok()?;
386        Self::from_bytes(buf)
387    }
388
389    /// Returns the `2*N`-digit lower-case hex representation of `self`.
390    fn to_hex(&self) -> String {
391        base16ct::lower::encode_string(&self.to_bytes())
392    }
393
394    /// Loads object from the `N`-byte buffer pointed to by `ptr`.
395    ///
396    /// # Safety
397    /// The caller must make sure that `ptr` is properly alligned,
398    /// the `N`-byte buffer is readable, and isn't modified for the duration of the call.
399    ///
400    /// See the 'Safety' section of [core::slice::from_raw_parts] for more details.
401    unsafe fn from_ptr(ptr: *const u8) -> Option<Self> {
402        Self::from_slice(unsafe { core::slice::from_raw_parts(ptr, N) })
403    }
404
405    /// Writes the `N`-byte representation of this object to the memory location `ptr`.
406    ///
407    /// # Safety
408    /// The caller must make sure that `ptr` is properly alligned,
409    /// the `N`-byte buffer is writable, and isn't modified for the duration of the call.
410    ///
411    /// See the 'Safety' section of [core::slice::from_raw_parts_mut] for more details.
412    unsafe fn copy_to_ptr(self, ptr: *mut u8) {
413        self.copy_to_slice(unsafe { core::slice::from_raw_parts_mut(ptr, N) })
414            .unwrap()
415        // Note: `copy_to_slice` only fails when the provided slice has the incorrect size (not `N`)
416        // which is not the case here.
417    }
418}
419
420impl Encoding<32> for Scalar {
421    fn from_bytes(bytes: [u8; 32]) -> Option<Scalar> {
422        Scalar::from_canonical_bytes(bytes).into()
423    }
424
425    fn to_bytes(&self) -> [u8; 32] {
426        Scalar::to_bytes(self)
427    }
428}
429
430impl Encoding<32> for CompressedRistretto {
431    fn from_bytes(bytes: [u8; 32]) -> Option<CompressedRistretto> {
432        Some(CompressedRistretto(bytes))
433    }
434
435    fn to_bytes(&self) -> [u8; 32] {
436        self.to_bytes()
437    }
438}
439
440impl Encoding<32> for RistrettoPoint {
441    fn from_bytes(bytes: [u8; 32]) -> Option<RistrettoPoint> {
442        CompressedRistretto(bytes).decompress()
443    }
444
445    fn to_bytes(&self) -> [u8; 32] {
446        self.compress().to_bytes()
447    }
448}
449
450impl Encoding<32> for PrivateKey {
451    fn from_bytes(bytes: [u8; 32]) -> Option<PrivateKey> {
452        Scalar::from_bytes(bytes).map(PrivateKey::from)
453    }
454
455    fn to_bytes(&self) -> [u8; 32] {
456        self.scalar.to_bytes()
457    }
458}
459
460impl Encoding<32> for PublicKey {
461    fn from_bytes(bytes: [u8; 32]) -> Option<PublicKey> {
462        CompressedRistretto::from_bytes(bytes)?.try_into().ok()
463    }
464
465    fn to_bytes(&self) -> [u8; 32] {
466        self.compressed.to_bytes()
467    }
468}
469
470impl Encoding<96> for Triple {
471    fn from_bytes(bytes: [u8; 96]) -> Option<Triple> {
472        let ek: RistrettoPoint = RistrettoPoint::from_slice(&bytes[..32])?;
473        let ct: RistrettoPoint = RistrettoPoint::from_slice(&bytes[32..64])?;
474        let pk: RistrettoPoint = RistrettoPoint::from_slice(&bytes[64..])?;
475
476        Some(Triple { ek, ct, pk })
477    }
478
479    fn to_bytes(&self) -> [u8; 96] {
480        let mut result = [0u8; 96];
481
482        // Note: `copy_to_slice` only fails when the slice's size is not 32, which it won't below
483        self.ek.copy_to_slice(&mut result[..32]).unwrap();
484        self.ct.copy_to_slice(&mut result[32..64]).unwrap();
485        self.pk.copy_to_slice(&mut result[64..]).unwrap();
486
487        result
488    }
489}
490
491mod serde_impls {
492    use super::*;
493    use crate::misc::serde_ext;
494    use serde::de::Error as _;
495
496    /// Implements [`serde::Serialize`] and [`serde::Deserialize`] using [`serde_ext::ByteArray`] and hex
497    /// encoding
498    macro_rules! serde_impl {
499        { $type:ident, $n:literal } => {
500
501            impl<'de> serde::Deserialize<'de> for $type {
502                fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
503                    let byte_array : serde_ext::ByteArray<$n>  =
504                        serde_ext::bytes_wrapper::B16::<serde_ext::ByteArray<$n>>::deserialize(d)?.into_inner();
505                    $type::from_bytes(byte_array.into()).ok_or_else(|| D::Error::custom(concat!("invalid ", stringify!($type))))
506                }
507            }
508
509            impl<'de> serde::Serialize for $type {
510                fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
511                    let byte_array = serde_ext::ByteArray::<$n>::from(self.to_bytes());
512                    serde_ext::bytes_wrapper::B16::<serde_ext::ByteArray<$n>>::from(byte_array)
513                        .serialize(s)
514                }
515            }
516        }
517    }
518
519    serde_impl! { PrivateKey, 32 }
520    serde_impl! { PublicKey, 32 }
521    serde_impl! { Triple, 96 }
522}
523
524/// Shared secret created by combining a [`PrivateKey`] with a [`PublicKey`], which, although it is
525/// basically the encoding of a [`RistrettoPoint`], is given a separate interface to limit its
526/// usage.
527#[derive(Clone, Debug, zeroize::ZeroizeOnDrop)]
528pub struct SharedSecret {
529    inner: [u8; 32],
530}
531
532impl crate::common::secret::DigestibleSecret for SharedSecret {
533    fn as_bytes(&self) -> &[u8] {
534        &self.inner
535    }
536}
537
538impl crate::common::secret::DigestibleSecret for PrivateKey {
539    fn as_bytes(&self) -> &[u8] {
540        self.scalar.as_bytes().as_slice()
541    }
542}
543
544///// Application binary interface
545//pub mod abi {
546//    use super::*;
547//
548//    /// Decrypts the given `ciphertext` using the given `private_key` and stores the result in
549//    /// `plaintext`.
550//    ///
551//    ///   * `plaintext` - pointer to a writable 32-byte buffer
//    ///   * `ciphertext` - pointer to a 96-byte buffer holding the result of [Triple::to_bytes]
553//    ///   * `private_key` - pointer to a 32-byte buffer holding the result of [Scalar::to_bytes]
554//    ///
555//    /// # Safety
556//    /// The caller must make sure the pointers are aligned, point to valid memory regions,
557//    /// are readable, and plaintext is writable, and are not otherwise modified.
558//    ///
559//    /// For more details, see [core::slice::from_raw_parts] and [core::slice::from_raw_parts_mut].
560//    #[unsafe(no_mangle)]
561//    pub unsafe extern "C" fn decrypt(
562//        plaintext: *mut u8,
563//        ciphertext: *const u8,
564//        private_key: *const u8,
565//    ) -> DecryptResult {
566//        let pk = match unsafe { PrivateKey::from_ptr(private_key) } {
567//            Some(pk) => pk,
568//            None => return DecryptResult::InvalidPrivateKey,
569//        };
570//
571//        let ct = match unsafe { Triple::from_ptr(ciphertext) } {
572//            Some(ct) => ct,
573//            None => return DecryptResult::InvalidTriple,
574//        };
575//
576//        let pt = match ct.decrypt_and_check_pk(&pk) {
577//            Some(pt) => pt,
578//            None => return DecryptResult::WrongPublicKey,
579//        };
580//
581//        unsafe { pt.copy_to_ptr(plaintext) }
582//
583//        DecryptResult::Ok
584//    }
585//
586//    /// Result of [decrypt].
587//    #[repr(u8)]
588//    pub enum DecryptResult {
589//        Ok = 1,
590//        WrongPublicKey = 2,
591//        InvalidTriple = 3,
592//        InvalidPrivateKey = 4,
593//    }
594//}