Skip to content

Commit d8e666a

Browse files
committed
Move visitors to a shared module
1 parent d7cccf4 commit d8e666a

File tree

4 files changed

+233
-202
lines changed

4 files changed

+233
-202
lines changed

serdect/src/array.rs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! Serialization primitives for arrays.
22
3-
// Unfortunately, we currently cannot assert generically that we are serializing
3+
// Unfortunately, we currently cannot tell `serde` in a uniform fashion that we are serializing
44
// a fixed-size byte array.
55
// See https://github.com/serde-rs/serde/issues/2120 for the discussion.
66
// Therefore we have to fall back to the slice methods,
@@ -13,7 +13,7 @@ use core::marker::PhantomData;
1313

1414
use serde::{Deserialize, Deserializer, Serialize, Serializer};
1515

16-
use crate::slice;
16+
use crate::common::{self, ExactLength, SliceVisitor, StrIntoBufVisitor};
1717

1818
#[cfg(feature = "zeroize")]
1919
use zeroize::Zeroize;
@@ -25,7 +25,7 @@ where
2525
S: Serializer,
2626
T: AsRef<[u8]>,
2727
{
28-
slice::serialize_hex_lower_or_bin(value, serializer)
28+
common::serialize_hex_lower_or_bin(value, serializer)
2929
}
3030

3131
/// Serialize the given type as upper case hex when using human-readable
@@ -35,7 +35,7 @@ where
3535
S: Serializer,
3636
T: AsRef<[u8]>,
3737
{
38-
slice::serialize_hex_upper_or_bin(value, serializer)
38+
common::serialize_hex_upper_or_bin(value, serializer)
3939
}
4040

4141
/// Deserialize from hex when using human-readable formats or binary if the
@@ -46,12 +46,9 @@ where
4646
D: Deserializer<'de>,
4747
{
4848
if deserializer.is_human_readable() {
49-
deserializer.deserialize_str(slice::StrVisitor::<slice::ExactLength>(buffer, PhantomData))
49+
deserializer.deserialize_str(StrIntoBufVisitor::<ExactLength>(buffer, PhantomData))
5050
} else {
51-
deserializer.deserialize_byte_buf(slice::SliceVisitor::<slice::ExactLength>(
52-
buffer,
53-
PhantomData,
54-
))
51+
deserializer.deserialize_byte_buf(SliceVisitor::<ExactLength>(buffer, PhantomData))
5552
}
5653
}
5754

serdect/src/common.rs

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
use core::fmt;
2+
use core::marker::PhantomData;
3+
4+
use serde::{
5+
de::{Error, Visitor},
6+
Serializer,
7+
};
8+
9+
#[cfg(feature = "alloc")]
10+
use ::{alloc::vec::Vec, serde::Serialize};
11+
12+
#[cfg(not(feature = "alloc"))]
13+
use serde::ser::Error as SerError;
14+
15+
pub(crate) fn serialize_hex<S, T, const UPPERCASE: bool>(
16+
value: &T,
17+
serializer: S,
18+
) -> Result<S::Ok, S::Error>
19+
where
20+
S: Serializer,
21+
T: AsRef<[u8]>,
22+
{
23+
#[cfg(feature = "alloc")]
24+
if UPPERCASE {
25+
return base16ct::upper::encode_string(value.as_ref()).serialize(serializer);
26+
} else {
27+
return base16ct::lower::encode_string(value.as_ref()).serialize(serializer);
28+
}
29+
#[cfg(not(feature = "alloc"))]
30+
{
31+
let _ = value;
32+
let _ = serializer;
33+
return Err(S::Error::custom(
34+
"serializer is human readable, which requires the `alloc` crate feature",
35+
));
36+
}
37+
}
38+
39+
pub(crate) fn serialize_hex_lower_or_bin<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
40+
where
41+
S: Serializer,
42+
T: AsRef<[u8]>,
43+
{
44+
if serializer.is_human_readable() {
45+
serialize_hex::<_, _, false>(value, serializer)
46+
} else {
47+
serializer.serialize_bytes(value.as_ref())
48+
}
49+
}
50+
51+
/// Serialize the given type as upper case hex when using human-readable
52+
/// formats or binary if the format is binary.
53+
pub(crate) fn serialize_hex_upper_or_bin<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
54+
where
55+
S: Serializer,
56+
T: AsRef<[u8]>,
57+
{
58+
if serializer.is_human_readable() {
59+
serialize_hex::<_, _, true>(value, serializer)
60+
} else {
61+
serializer.serialize_bytes(value.as_ref())
62+
}
63+
}
64+
65+
pub(crate) trait LengthCheck {
66+
fn length_check(buffer_length: usize, data_length: usize) -> bool;
67+
fn expecting(
68+
formatter: &mut fmt::Formatter<'_>,
69+
data_type: &str,
70+
data_length: usize,
71+
) -> fmt::Result;
72+
}
73+
74+
pub(crate) struct ExactLength;
75+
76+
impl LengthCheck for ExactLength {
77+
fn length_check(buffer_length: usize, data_length: usize) -> bool {
78+
buffer_length == data_length
79+
}
80+
fn expecting(
81+
formatter: &mut fmt::Formatter<'_>,
82+
data_type: &str,
83+
data_length: usize,
84+
) -> fmt::Result {
85+
write!(formatter, "{} of length {}", data_type, data_length)
86+
}
87+
}
88+
89+
pub(crate) struct UpperBound;
90+
91+
impl LengthCheck for UpperBound {
92+
fn length_check(buffer_length: usize, data_length: usize) -> bool {
93+
buffer_length >= data_length
94+
}
95+
fn expecting(
96+
formatter: &mut fmt::Formatter<'_>,
97+
data_type: &str,
98+
data_length: usize,
99+
) -> fmt::Result {
100+
write!(
101+
formatter,
102+
"{} with a maximum length of {}",
103+
data_type, data_length
104+
)
105+
}
106+
}
107+
108+
pub(crate) struct StrIntoBufVisitor<'b, T: LengthCheck>(pub &'b mut [u8], pub PhantomData<T>);
109+
110+
impl<'de, 'b, T: LengthCheck> Visitor<'de> for StrIntoBufVisitor<'b, T> {
111+
type Value = ();
112+
113+
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
114+
T::expecting(formatter, "a string", self.0.len() * 2)
115+
}
116+
117+
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
118+
where
119+
E: Error,
120+
{
121+
if !T::length_check(self.0.len() * 2, v.len()) {
122+
return Err(Error::invalid_length(v.len(), &self));
123+
}
124+
// TODO: Map `base16ct::Error::InvalidLength` to `Error::invalid_length`.
125+
base16ct::mixed::decode(v, self.0)
126+
.map(|_| ())
127+
.map_err(E::custom)
128+
}
129+
}
130+
131+
#[cfg(feature = "alloc")]
132+
pub(crate) struct StrIntoVecVisitor;
133+
134+
#[cfg(feature = "alloc")]
135+
impl<'de> Visitor<'de> for StrIntoVecVisitor {
136+
type Value = Vec<u8>;
137+
138+
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
139+
write!(formatter, "a string")
140+
}
141+
142+
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
143+
where
144+
E: Error,
145+
{
146+
base16ct::mixed::decode_vec(v).map_err(E::custom)
147+
}
148+
}
149+
150+
pub(crate) struct SliceVisitor<'b, T: LengthCheck>(pub &'b mut [u8], pub PhantomData<T>);
151+
152+
impl<'de, 'b, T: LengthCheck> Visitor<'de> for SliceVisitor<'b, T> {
153+
type Value = ();
154+
155+
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
156+
T::expecting(formatter, "an array", self.0.len())
157+
}
158+
159+
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
160+
where
161+
E: Error,
162+
{
163+
// Workaround for
164+
// https://github.com/rust-lang/rfcs/blob/b1de05846d9bc5591d753f611ab8ee84a01fa500/text/2094-nll.md#problem-case-3-conditional-control-flow-across-functions
165+
if T::length_check(self.0.len(), v.len()) {
166+
let buffer = &mut self.0[..v.len()];
167+
buffer.copy_from_slice(v);
168+
return Ok(());
169+
}
170+
171+
Err(E::invalid_length(v.len(), &self))
172+
}
173+
174+
#[cfg(feature = "alloc")]
175+
fn visit_byte_buf<E>(self, mut v: Vec<u8>) -> Result<Self::Value, E>
176+
where
177+
E: Error,
178+
{
179+
// Workaround for
180+
// https://github.com/rust-lang/rfcs/blob/b1de05846d9bc5591d753f611ab8ee84a01fa500/text/2094-nll.md#problem-case-3-conditional-control-flow-across-functions
181+
if T::length_check(self.0.len(), v.len()) {
182+
let buffer = &mut self.0[..v.len()];
183+
buffer.swap_with_slice(&mut v);
184+
return Ok(());
185+
}
186+
187+
Err(E::invalid_length(v.len(), &self))
188+
}
189+
}
190+
191+
#[cfg(feature = "alloc")]
192+
pub(crate) struct VecVisitor;
193+
194+
#[cfg(feature = "alloc")]
195+
impl<'de> Visitor<'de> for VecVisitor {
196+
type Value = Vec<u8>;
197+
198+
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
199+
write!(formatter, "a bytestring")
200+
}
201+
202+
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
203+
where
204+
E: Error,
205+
{
206+
Ok(v.into())
207+
}
208+
209+
fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
210+
where
211+
E: Error,
212+
{
213+
Ok(v)
214+
}
215+
}

serdect/src/lib.rs

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -131,35 +131,7 @@
131131
extern crate alloc;
132132

133133
pub mod array;
134+
mod common;
134135
pub mod slice;
135136

136137
pub use serde;
137-
138-
use serde::Serializer;
139-
140-
#[cfg(not(feature = "alloc"))]
141-
use serde::ser::Error;
142-
143-
#[cfg(feature = "alloc")]
144-
use serde::Serialize;
145-
146-
fn serialize_hex<S, T, const UPPERCASE: bool>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
147-
where
148-
S: Serializer,
149-
T: AsRef<[u8]>,
150-
{
151-
#[cfg(feature = "alloc")]
152-
if UPPERCASE {
153-
return base16ct::upper::encode_string(value.as_ref()).serialize(serializer);
154-
} else {
155-
return base16ct::lower::encode_string(value.as_ref()).serialize(serializer);
156-
}
157-
#[cfg(not(feature = "alloc"))]
158-
{
159-
let _ = value;
160-
let _ = serializer;
161-
return Err(S::Error::custom(
162-
"serializer is human readable, which requires the `alloc` crate feature",
163-
));
164-
}
165-
}

0 commit comments

Comments
 (0)