Skip to content

Commit 4181e81

Browse files
committed
Move visitors to a shared module
1 parent d7cccf4 commit 4181e81

File tree

4 files changed

+215
-182
lines changed

4 files changed

+215
-182
lines changed

serdect/src/array.rs

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! Serialization primitives for arrays.
22
3-
// Unfortunately, we currently cannot assert generically that we are serializing
3+
// Unfortunately, we currently cannot tell `serde` in a uniform fashion that we are serializing
44
// a fixed-size byte array.
55
// See https://github.com/serde-rs/serde/issues/2120 for the discussion.
66
// Therefore we have to fall back to the slice methods,
@@ -9,11 +9,12 @@
99
// to be exactly equal to the size of the buffer during deserialization,
1010
// while for slices the buffer can be larger than the deserialized data.
1111

12+
use core::fmt;
1213
use core::marker::PhantomData;
1314

1415
use serde::{Deserialize, Deserializer, Serialize, Serializer};
1516

16-
use crate::slice;
17+
use crate::common::{self, LengthCheck, SliceVisitor, StrIntoBufVisitor};
1718

1819
#[cfg(feature = "zeroize")]
1920
use zeroize::Zeroize;
@@ -25,7 +26,7 @@ where
2526
S: Serializer,
2627
T: AsRef<[u8]>,
2728
{
28-
slice::serialize_hex_lower_or_bin(value, serializer)
29+
common::serialize_hex_lower_or_bin(value, serializer)
2930
}
3031

3132
/// Serialize the given type as upper case hex when using human-readable
@@ -35,7 +36,22 @@ where
3536
S: Serializer,
3637
T: AsRef<[u8]>,
3738
{
38-
slice::serialize_hex_upper_or_bin(value, serializer)
39+
common::serialize_hex_upper_or_bin(value, serializer)
40+
}
41+
42+
struct ExactLength;
43+
44+
impl LengthCheck for ExactLength {
45+
fn length_check(buffer_length: usize, data_length: usize) -> bool {
46+
buffer_length == data_length
47+
}
48+
fn expecting(
49+
formatter: &mut fmt::Formatter<'_>,
50+
data_type: &str,
51+
data_length: usize,
52+
) -> fmt::Result {
53+
write!(formatter, "{} of length {}", data_type, data_length)
54+
}
3955
}
4056

4157
/// Deserialize from hex when using human-readable formats or binary if the
@@ -46,12 +62,9 @@ where
4662
D: Deserializer<'de>,
4763
{
4864
if deserializer.is_human_readable() {
49-
deserializer.deserialize_str(slice::StrVisitor::<slice::ExactLength>(buffer, PhantomData))
65+
deserializer.deserialize_str(StrIntoBufVisitor::<ExactLength>(buffer, PhantomData))
5066
} else {
51-
deserializer.deserialize_byte_buf(slice::SliceVisitor::<slice::ExactLength>(
52-
buffer,
53-
PhantomData,
54-
))
67+
deserializer.deserialize_byte_buf(SliceVisitor::<ExactLength>(buffer, PhantomData))
5568
}
5669
}
5770

serdect/src/common.rs

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
use core::fmt;
2+
use core::marker::PhantomData;
3+
4+
use serde::{
5+
de::{Error, Visitor},
6+
Serializer,
7+
};
8+
9+
#[cfg(feature = "alloc")]
10+
use ::{alloc::vec::Vec, serde::Serialize};
11+
12+
#[cfg(not(feature = "alloc"))]
13+
use serde::ser::Error as SerError;
14+
15+
pub(crate) fn serialize_hex<S, T, const UPPERCASE: bool>(
16+
value: &T,
17+
serializer: S,
18+
) -> Result<S::Ok, S::Error>
19+
where
20+
S: Serializer,
21+
T: AsRef<[u8]>,
22+
{
23+
#[cfg(feature = "alloc")]
24+
if UPPERCASE {
25+
return base16ct::upper::encode_string(value.as_ref()).serialize(serializer);
26+
} else {
27+
return base16ct::lower::encode_string(value.as_ref()).serialize(serializer);
28+
}
29+
#[cfg(not(feature = "alloc"))]
30+
{
31+
let _ = value;
32+
let _ = serializer;
33+
return Err(S::Error::custom(
34+
"serializer is human readable, which requires the `alloc` crate feature",
35+
));
36+
}
37+
}
38+
39+
pub(crate) fn serialize_hex_lower_or_bin<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
40+
where
41+
S: Serializer,
42+
T: AsRef<[u8]>,
43+
{
44+
if serializer.is_human_readable() {
45+
serialize_hex::<_, _, false>(value, serializer)
46+
} else {
47+
serializer.serialize_bytes(value.as_ref())
48+
}
49+
}
50+
51+
/// Serialize the given type as upper case hex when using human-readable
52+
/// formats or binary if the format is binary.
53+
pub(crate) fn serialize_hex_upper_or_bin<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
54+
where
55+
S: Serializer,
56+
T: AsRef<[u8]>,
57+
{
58+
if serializer.is_human_readable() {
59+
serialize_hex::<_, _, true>(value, serializer)
60+
} else {
61+
serializer.serialize_bytes(value.as_ref())
62+
}
63+
}
64+
65+
pub(crate) trait LengthCheck {
66+
fn length_check(buffer_length: usize, data_length: usize) -> bool;
67+
fn expecting(
68+
formatter: &mut fmt::Formatter<'_>,
69+
data_type: &str,
70+
data_length: usize,
71+
) -> fmt::Result;
72+
}
73+
74+
pub(crate) struct StrIntoBufVisitor<'b, T: LengthCheck>(pub &'b mut [u8], pub PhantomData<T>);
75+
76+
impl<'de, 'b, T: LengthCheck> Visitor<'de> for StrIntoBufVisitor<'b, T> {
77+
type Value = ();
78+
79+
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
80+
T::expecting(formatter, "a string", self.0.len() * 2)
81+
}
82+
83+
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
84+
where
85+
E: Error,
86+
{
87+
if !T::length_check(self.0.len() * 2, v.len()) {
88+
return Err(Error::invalid_length(v.len(), &self));
89+
}
90+
// TODO: Map `base16ct::Error::InvalidLength` to `Error::invalid_length`.
91+
base16ct::mixed::decode(v, self.0)
92+
.map(|_| ())
93+
.map_err(E::custom)
94+
}
95+
}
96+
97+
#[cfg(feature = "alloc")]
98+
pub(crate) struct StrIntoVecVisitor;
99+
100+
#[cfg(feature = "alloc")]
101+
impl<'de> Visitor<'de> for StrIntoVecVisitor {
102+
type Value = Vec<u8>;
103+
104+
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
105+
write!(formatter, "a string")
106+
}
107+
108+
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
109+
where
110+
E: Error,
111+
{
112+
base16ct::mixed::decode_vec(v).map_err(E::custom)
113+
}
114+
}
115+
116+
pub(crate) struct SliceVisitor<'b, T: LengthCheck>(pub &'b mut [u8], pub PhantomData<T>);
117+
118+
impl<'de, 'b, T: LengthCheck> Visitor<'de> for SliceVisitor<'b, T> {
119+
type Value = ();
120+
121+
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
122+
T::expecting(formatter, "an array", self.0.len())
123+
}
124+
125+
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
126+
where
127+
E: Error,
128+
{
129+
// Workaround for
130+
// https://github.com/rust-lang/rfcs/blob/b1de05846d9bc5591d753f611ab8ee84a01fa500/text/2094-nll.md#problem-case-3-conditional-control-flow-across-functions
131+
if T::length_check(self.0.len(), v.len()) {
132+
let buffer = &mut self.0[..v.len()];
133+
buffer.copy_from_slice(v);
134+
return Ok(());
135+
}
136+
137+
Err(E::invalid_length(v.len(), &self))
138+
}
139+
140+
#[cfg(feature = "alloc")]
141+
fn visit_byte_buf<E>(self, mut v: Vec<u8>) -> Result<Self::Value, E>
142+
where
143+
E: Error,
144+
{
145+
// Workaround for
146+
// https://github.com/rust-lang/rfcs/blob/b1de05846d9bc5591d753f611ab8ee84a01fa500/text/2094-nll.md#problem-case-3-conditional-control-flow-across-functions
147+
if T::length_check(self.0.len(), v.len()) {
148+
let buffer = &mut self.0[..v.len()];
149+
buffer.swap_with_slice(&mut v);
150+
return Ok(());
151+
}
152+
153+
Err(E::invalid_length(v.len(), &self))
154+
}
155+
}
156+
157+
#[cfg(feature = "alloc")]
158+
pub(crate) struct VecVisitor;
159+
160+
#[cfg(feature = "alloc")]
161+
impl<'de> Visitor<'de> for VecVisitor {
162+
type Value = Vec<u8>;
163+
164+
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
165+
write!(formatter, "a bytestring")
166+
}
167+
168+
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
169+
where
170+
E: Error,
171+
{
172+
Ok(v.into())
173+
}
174+
175+
fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
176+
where
177+
E: Error,
178+
{
179+
Ok(v)
180+
}
181+
}

serdect/src/lib.rs

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -131,35 +131,7 @@
131131
extern crate alloc;
132132

133133
pub mod array;
134+
mod common;
134135
pub mod slice;
135136

136137
pub use serde;
137-
138-
use serde::Serializer;
139-
140-
#[cfg(not(feature = "alloc"))]
141-
use serde::ser::Error;
142-
143-
#[cfg(feature = "alloc")]
144-
use serde::Serialize;
145-
146-
fn serialize_hex<S, T, const UPPERCASE: bool>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
147-
where
148-
S: Serializer,
149-
T: AsRef<[u8]>,
150-
{
151-
#[cfg(feature = "alloc")]
152-
if UPPERCASE {
153-
return base16ct::upper::encode_string(value.as_ref()).serialize(serializer);
154-
} else {
155-
return base16ct::lower::encode_string(value.as_ref()).serialize(serializer);
156-
}
157-
#[cfg(not(feature = "alloc"))]
158-
{
159-
let _ = value;
160-
let _ = serializer;
161-
return Err(S::Error::custom(
162-
"serializer is human readable, which requires the `alloc` crate feature",
163-
));
164-
}
165-
}

0 commit comments

Comments
 (0)