|
| 1 | +use core::fmt; |
| 2 | +use core::marker::PhantomData; |
| 3 | + |
| 4 | +use serde::{ |
| 5 | + de::{Error, Visitor}, |
| 6 | + Serializer, |
| 7 | +}; |
| 8 | + |
| 9 | +#[cfg(feature = "alloc")] |
| 10 | +use ::{alloc::vec::Vec, serde::Serialize}; |
| 11 | + |
| 12 | +#[cfg(not(feature = "alloc"))] |
| 13 | +use serde::ser::Error as SerError; |
| 14 | + |
| 15 | +pub(crate) fn serialize_hex<S, T, const UPPERCASE: bool>( |
| 16 | + value: &T, |
| 17 | + serializer: S, |
| 18 | +) -> Result<S::Ok, S::Error> |
| 19 | +where |
| 20 | + S: Serializer, |
| 21 | + T: AsRef<[u8]>, |
| 22 | +{ |
| 23 | + #[cfg(feature = "alloc")] |
| 24 | + if UPPERCASE { |
| 25 | + return base16ct::upper::encode_string(value.as_ref()).serialize(serializer); |
| 26 | + } else { |
| 27 | + return base16ct::lower::encode_string(value.as_ref()).serialize(serializer); |
| 28 | + } |
| 29 | + #[cfg(not(feature = "alloc"))] |
| 30 | + { |
| 31 | + let _ = value; |
| 32 | + let _ = serializer; |
| 33 | + return Err(S::Error::custom( |
| 34 | + "serializer is human readable, which requires the `alloc` crate feature", |
| 35 | + )); |
| 36 | + } |
| 37 | +} |
| 38 | + |
| 39 | +pub(crate) fn serialize_hex_lower_or_bin<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error> |
| 40 | +where |
| 41 | + S: Serializer, |
| 42 | + T: AsRef<[u8]>, |
| 43 | +{ |
| 44 | + if serializer.is_human_readable() { |
| 45 | + serialize_hex::<_, _, false>(value, serializer) |
| 46 | + } else { |
| 47 | + serializer.serialize_bytes(value.as_ref()) |
| 48 | + } |
| 49 | +} |
| 50 | + |
| 51 | +/// Serialize the given type as upper case hex when using human-readable |
| 52 | +/// formats or binary if the format is binary. |
| 53 | +pub(crate) fn serialize_hex_upper_or_bin<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error> |
| 54 | +where |
| 55 | + S: Serializer, |
| 56 | + T: AsRef<[u8]>, |
| 57 | +{ |
| 58 | + if serializer.is_human_readable() { |
| 59 | + serialize_hex::<_, _, true>(value, serializer) |
| 60 | + } else { |
| 61 | + serializer.serialize_bytes(value.as_ref()) |
| 62 | + } |
| 63 | +} |
| 64 | + |
| 65 | +pub(crate) trait LengthCheck { |
| 66 | + fn length_check(buffer_length: usize, data_length: usize) -> bool; |
| 67 | + fn expecting( |
| 68 | + formatter: &mut fmt::Formatter<'_>, |
| 69 | + data_type: &str, |
| 70 | + data_length: usize, |
| 71 | + ) -> fmt::Result; |
| 72 | +} |
| 73 | + |
| 74 | +pub(crate) struct ExactLength; |
| 75 | + |
| 76 | +impl LengthCheck for ExactLength { |
| 77 | + fn length_check(buffer_length: usize, data_length: usize) -> bool { |
| 78 | + buffer_length == data_length |
| 79 | + } |
| 80 | + fn expecting( |
| 81 | + formatter: &mut fmt::Formatter<'_>, |
| 82 | + data_type: &str, |
| 83 | + data_length: usize, |
| 84 | + ) -> fmt::Result { |
| 85 | + write!(formatter, "{} of length {}", data_type, data_length) |
| 86 | + } |
| 87 | +} |
| 88 | + |
| 89 | +pub(crate) struct UpperBound; |
| 90 | + |
| 91 | +impl LengthCheck for UpperBound { |
| 92 | + fn length_check(buffer_length: usize, data_length: usize) -> bool { |
| 93 | + buffer_length >= data_length |
| 94 | + } |
| 95 | + fn expecting( |
| 96 | + formatter: &mut fmt::Formatter<'_>, |
| 97 | + data_type: &str, |
| 98 | + data_length: usize, |
| 99 | + ) -> fmt::Result { |
| 100 | + write!( |
| 101 | + formatter, |
| 102 | + "{} with a maximum length of {}", |
| 103 | + data_type, data_length |
| 104 | + ) |
| 105 | + } |
| 106 | +} |
| 107 | + |
| 108 | +pub(crate) struct StrIntoBufVisitor<'b, T: LengthCheck>(pub &'b mut [u8], pub PhantomData<T>); |
| 109 | + |
| 110 | +impl<'de, 'b, T: LengthCheck> Visitor<'de> for StrIntoBufVisitor<'b, T> { |
| 111 | + type Value = (); |
| 112 | + |
| 113 | + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 114 | + T::expecting(formatter, "a string", self.0.len() * 2) |
| 115 | + } |
| 116 | + |
| 117 | + fn visit_str<E>(self, v: &str) -> Result<Self::Value, E> |
| 118 | + where |
| 119 | + E: Error, |
| 120 | + { |
| 121 | + if !T::length_check(self.0.len() * 2, v.len()) { |
| 122 | + return Err(Error::invalid_length(v.len(), &self)); |
| 123 | + } |
| 124 | + // TODO: Map `base16ct::Error::InvalidLength` to `Error::invalid_length`. |
| 125 | + base16ct::mixed::decode(v, self.0) |
| 126 | + .map(|_| ()) |
| 127 | + .map_err(E::custom) |
| 128 | + } |
| 129 | +} |
| 130 | + |
| 131 | +#[cfg(feature = "alloc")] |
| 132 | +pub(crate) struct StrIntoVecVisitor; |
| 133 | + |
| 134 | +#[cfg(feature = "alloc")] |
| 135 | +impl<'de> Visitor<'de> for StrIntoVecVisitor { |
| 136 | + type Value = Vec<u8>; |
| 137 | + |
| 138 | + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 139 | + write!(formatter, "a string") |
| 140 | + } |
| 141 | + |
| 142 | + fn visit_str<E>(self, v: &str) -> Result<Self::Value, E> |
| 143 | + where |
| 144 | + E: Error, |
| 145 | + { |
| 146 | + base16ct::mixed::decode_vec(v).map_err(E::custom) |
| 147 | + } |
| 148 | +} |
| 149 | + |
| 150 | +pub(crate) struct SliceVisitor<'b, T: LengthCheck>(pub &'b mut [u8], pub PhantomData<T>); |
| 151 | + |
| 152 | +impl<'de, 'b, T: LengthCheck> Visitor<'de> for SliceVisitor<'b, T> { |
| 153 | + type Value = (); |
| 154 | + |
| 155 | + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 156 | + T::expecting(formatter, "an array", self.0.len()) |
| 157 | + } |
| 158 | + |
| 159 | + fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E> |
| 160 | + where |
| 161 | + E: Error, |
| 162 | + { |
| 163 | + // Workaround for |
| 164 | + // https://github.com/rust-lang/rfcs/blob/b1de05846d9bc5591d753f611ab8ee84a01fa500/text/2094-nll.md#problem-case-3-conditional-control-flow-across-functions |
| 165 | + if T::length_check(self.0.len(), v.len()) { |
| 166 | + let buffer = &mut self.0[..v.len()]; |
| 167 | + buffer.copy_from_slice(v); |
| 168 | + return Ok(()); |
| 169 | + } |
| 170 | + |
| 171 | + Err(E::invalid_length(v.len(), &self)) |
| 172 | + } |
| 173 | + |
| 174 | + #[cfg(feature = "alloc")] |
| 175 | + fn visit_byte_buf<E>(self, mut v: Vec<u8>) -> Result<Self::Value, E> |
| 176 | + where |
| 177 | + E: Error, |
| 178 | + { |
| 179 | + // Workaround for |
| 180 | + // https://github.com/rust-lang/rfcs/blob/b1de05846d9bc5591d753f611ab8ee84a01fa500/text/2094-nll.md#problem-case-3-conditional-control-flow-across-functions |
| 181 | + if T::length_check(self.0.len(), v.len()) { |
| 182 | + let buffer = &mut self.0[..v.len()]; |
| 183 | + buffer.swap_with_slice(&mut v); |
| 184 | + return Ok(()); |
| 185 | + } |
| 186 | + |
| 187 | + Err(E::invalid_length(v.len(), &self)) |
| 188 | + } |
| 189 | +} |
| 190 | + |
| 191 | +#[cfg(feature = "alloc")] |
| 192 | +pub(crate) struct VecVisitor; |
| 193 | + |
| 194 | +#[cfg(feature = "alloc")] |
| 195 | +impl<'de> Visitor<'de> for VecVisitor { |
| 196 | + type Value = Vec<u8>; |
| 197 | + |
| 198 | + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 199 | + write!(formatter, "a bytestring") |
| 200 | + } |
| 201 | + |
| 202 | + fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E> |
| 203 | + where |
| 204 | + E: Error, |
| 205 | + { |
| 206 | + Ok(v.into()) |
| 207 | + } |
| 208 | + |
| 209 | + fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E> |
| 210 | + where |
| 211 | + E: Error, |
| 212 | + { |
| 213 | + Ok(v) |
| 214 | + } |
| 215 | +} |
0 commit comments