Skip to content

Commit d84ec9f

Browse files
authored
Allow configuring the default value for primitives types (#39)
* Allow configuring the default value for primitives types during fake deserialization * add documentation
1 parent d842f82 commit d84ec9f

File tree

5 files changed

+142
-18
lines changed

5 files changed

+142
-18
lines changed

serde-reflection/README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,13 @@ use the crate [`serde-name`](https://crates.io/crates/serde-name) and its adapte
105105
terminate. (Work around: re-order the variants. For instance `enum List {
106106
Some(Box<List>), None}` must be rewritten `enum List { None, Some(Box<List>)}`.)
107107

108+
* Certain standard types such as `std::num::NonZeroU8` may not be tracked as a
109+
container and appear simply as their underlying primitive type (e.g. `u8`) in the
110+
formats. This loss of information makes it difficult to use `trace_value` to work
111+
around deserialization invariants (see example below). As a work around, you may
112+
override the default for the primitive type using `TracerConfig` (e.g. `let config =
113+
TracerConfig::default().default_u8_value(1);`).
114+
108115
### Security CAVEAT
109116

110117
At this time, `HashSet<T>` and `BTreeSet<T>` are treated as sequences (i.e. vectors)

serde-reflection/src/de.rs

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -50,143 +50,143 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> {
5050
V: Visitor<'de>,
5151
{
5252
self.format.unify(Format::Bool)?;
53-
visitor.visit_bool(false)
53+
visitor.visit_bool(self.tracer.config.default_bool_value)
5454
}
5555

5656
fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
5757
where
5858
V: Visitor<'de>,
5959
{
6060
self.format.unify(Format::I8)?;
61-
visitor.visit_i8(0)
61+
visitor.visit_i8(self.tracer.config.default_i8_value)
6262
}
6363

6464
fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
6565
where
6666
V: Visitor<'de>,
6767
{
6868
self.format.unify(Format::I16)?;
69-
visitor.visit_i16(0)
69+
visitor.visit_i16(self.tracer.config.default_i16_value)
7070
}
7171

7272
fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
7373
where
7474
V: Visitor<'de>,
7575
{
7676
self.format.unify(Format::I32)?;
77-
visitor.visit_i32(0)
77+
visitor.visit_i32(self.tracer.config.default_i32_value)
7878
}
7979

8080
fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
8181
where
8282
V: Visitor<'de>,
8383
{
8484
self.format.unify(Format::I64)?;
85-
visitor.visit_i64(0)
85+
visitor.visit_i64(self.tracer.config.default_i64_value)
8686
}
8787

8888
fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value>
8989
where
9090
V: Visitor<'de>,
9191
{
9292
self.format.unify(Format::I128)?;
93-
visitor.visit_i128(0)
93+
visitor.visit_i128(self.tracer.config.default_i128_value)
9494
}
9595

9696
fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
9797
where
9898
V: Visitor<'de>,
9999
{
100100
self.format.unify(Format::U8)?;
101-
visitor.visit_u8(0)
101+
visitor.visit_u8(self.tracer.config.default_u8_value)
102102
}
103103

104104
fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
105105
where
106106
V: Visitor<'de>,
107107
{
108108
self.format.unify(Format::U16)?;
109-
visitor.visit_u16(0)
109+
visitor.visit_u16(self.tracer.config.default_u16_value)
110110
}
111111

112112
fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
113113
where
114114
V: Visitor<'de>,
115115
{
116116
self.format.unify(Format::U32)?;
117-
visitor.visit_u32(0)
117+
visitor.visit_u32(self.tracer.config.default_u32_value)
118118
}
119119

120120
fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
121121
where
122122
V: Visitor<'de>,
123123
{
124124
self.format.unify(Format::U64)?;
125-
visitor.visit_u64(0)
125+
visitor.visit_u64(self.tracer.config.default_u64_value)
126126
}
127127

128128
fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value>
129129
where
130130
V: Visitor<'de>,
131131
{
132132
self.format.unify(Format::U128)?;
133-
visitor.visit_u128(0)
133+
visitor.visit_u128(self.tracer.config.default_u128_value)
134134
}
135135

136136
fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
137137
where
138138
V: Visitor<'de>,
139139
{
140140
self.format.unify(Format::F32)?;
141-
visitor.visit_f32(0.0)
141+
visitor.visit_f32(self.tracer.config.default_f32_value)
142142
}
143143

144144
fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
145145
where
146146
V: Visitor<'de>,
147147
{
148148
self.format.unify(Format::F64)?;
149-
visitor.visit_f64(0.0)
149+
visitor.visit_f64(self.tracer.config.default_f64_value)
150150
}
151151

152152
fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
153153
where
154154
V: Visitor<'de>,
155155
{
156156
self.format.unify(Format::Char)?;
157-
visitor.visit_char('A')
157+
visitor.visit_char(self.tracer.config.default_char_value)
158158
}
159159

160160
fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
161161
where
162162
V: Visitor<'de>,
163163
{
164164
self.format.unify(Format::Str)?;
165-
visitor.visit_borrowed_str("")
165+
visitor.visit_borrowed_str(self.tracer.config.default_borrowed_str_value)
166166
}
167167

168168
fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
169169
where
170170
V: Visitor<'de>,
171171
{
172172
self.format.unify(Format::Str)?;
173-
visitor.visit_string(String::new())
173+
visitor.visit_string(self.tracer.config.default_string_value.clone())
174174
}
175175

176176
fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
177177
where
178178
V: Visitor<'de>,
179179
{
180180
self.format.unify(Format::Bytes)?;
181-
visitor.visit_borrowed_bytes(b"")
181+
visitor.visit_borrowed_bytes(self.tracer.config.default_borrowed_bytes_value)
182182
}
183183

184184
fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
185185
where
186186
V: Visitor<'de>,
187187
{
188188
self.format.unify(Format::Bytes)?;
189-
visitor.visit_byte_buf(Vec::new())
189+
visitor.visit_byte_buf(self.tracer.config.default_byte_buf_value.clone())
190190
}
191191

192192
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>

serde-reflection/src/lib.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,13 @@
108108
//! terminate. (Work around: re-order the variants. For instance `enum List {
109109
//! Some(Box<List>), None}` must be rewritten `enum List { None, Some(Box<List>)}`.)
110110
//!
111+
//! * Certain standard types such as `std::num::NonZeroU8` may not be tracked as a
112+
//! container and appear simply as their underlying primitive type (e.g. `u8`) in the
113+
//! formats. This loss of information makes it difficult to use `trace_value` to work
114+
//! around deserialization invariants (see example below). As a work around, you may
115+
//! override the default for the primitive type using `TracerConfig` (e.g. `let config =
116+
//! TracerConfig::default().default_u8_value(1);`).
117+
//!
111118
//! ## Security CAVEAT
112119
//!
113120
//! At this time, `HashSet<T>` and `BTreeSet<T>` are treated as sequences (i.e. vectors)

serde-reflection/src/trace.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,24 @@ pub struct TracerConfig {
5757
pub(crate) record_samples_for_newtype_structs: bool,
5858
pub(crate) record_samples_for_tuple_structs: bool,
5959
pub(crate) record_samples_for_structs: bool,
60+
pub(crate) default_bool_value: bool,
61+
pub(crate) default_u8_value: u8,
62+
pub(crate) default_u16_value: u16,
63+
pub(crate) default_u32_value: u32,
64+
pub(crate) default_u64_value: u64,
65+
pub(crate) default_u128_value: u128,
66+
pub(crate) default_i8_value: i8,
67+
pub(crate) default_i16_value: i16,
68+
pub(crate) default_i32_value: i32,
69+
pub(crate) default_i64_value: i64,
70+
pub(crate) default_i128_value: i128,
71+
pub(crate) default_f32_value: f32,
72+
pub(crate) default_f64_value: f64,
73+
pub(crate) default_char_value: char,
74+
pub(crate) default_borrowed_str_value: &'static str,
75+
pub(crate) default_string_value: String,
76+
pub(crate) default_borrowed_bytes_value: &'static [u8],
77+
pub(crate) default_byte_buf_value: Vec<u8>,
6078
}
6179

6280
impl Default for TracerConfig {
@@ -67,10 +85,38 @@ impl Default for TracerConfig {
6785
record_samples_for_newtype_structs: true,
6886
record_samples_for_tuple_structs: false,
6987
record_samples_for_structs: false,
88+
default_bool_value: false,
89+
default_u8_value: 0,
90+
default_u16_value: 0,
91+
default_u32_value: 0,
92+
default_u64_value: 0,
93+
default_u128_value: 0,
94+
default_i8_value: 0,
95+
default_i16_value: 0,
96+
default_i32_value: 0,
97+
default_i64_value: 0,
98+
default_i128_value: 0,
99+
default_f32_value: 0.0,
100+
default_f64_value: 0.0,
101+
default_char_value: 'A',
102+
default_borrowed_str_value: "",
103+
default_string_value: String::new(),
104+
default_borrowed_bytes_value: b"",
105+
default_byte_buf_value: Vec::new(),
70106
}
71107
}
72108
}
73109

110+
macro_rules! define_default_value_setter {
111+
($method:ident, $ty:ty) => {
112+
/// The default serialized value for this primitive type.
113+
pub fn $method(mut self, value: $ty) -> Self {
114+
self.$method = value;
115+
self
116+
}
117+
};
118+
}
119+
74120
impl TracerConfig {
75121
/// Whether to trace the human readable encoding of (de)serialization.
76122
#[allow(clippy::wrong_self_convention)]
@@ -96,6 +142,25 @@ impl TracerConfig {
96142
self.record_samples_for_structs = value;
97143
self
98144
}
145+
146+
define_default_value_setter!(default_bool_value, bool);
147+
define_default_value_setter!(default_u8_value, u8);
148+
define_default_value_setter!(default_u16_value, u16);
149+
define_default_value_setter!(default_u32_value, u32);
150+
define_default_value_setter!(default_u64_value, u64);
151+
define_default_value_setter!(default_u128_value, u128);
152+
define_default_value_setter!(default_i8_value, i8);
153+
define_default_value_setter!(default_i16_value, i16);
154+
define_default_value_setter!(default_i32_value, i32);
155+
define_default_value_setter!(default_i64_value, i64);
156+
define_default_value_setter!(default_i128_value, i128);
157+
define_default_value_setter!(default_f32_value, f32);
158+
define_default_value_setter!(default_f64_value, f64);
159+
define_default_value_setter!(default_char_value, char);
160+
define_default_value_setter!(default_borrowed_str_value, &'static str);
161+
define_default_value_setter!(default_string_value, String);
162+
define_default_value_setter!(default_borrowed_bytes_value, &'static [u8]);
163+
define_default_value_setter!(default_byte_buf_value, Vec<u8>);
99164
}
100165

101166
impl Tracer {

serde-reflection/tests/serde.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,3 +457,48 @@ fn test_repeated_tracing() {
457457
))))))
458458
);
459459
}
460+
461+
#[test]
462+
fn test_default_value_for_primitive_types() {
463+
let config = TracerConfig::default()
464+
.default_u8_value(1)
465+
.default_u16_value(2)
466+
.default_u32_value(3)
467+
.default_u64_value(4)
468+
.default_u128_value(5)
469+
.default_i8_value(6)
470+
.default_i16_value(7)
471+
.default_i32_value(8)
472+
.default_i64_value(9)
473+
.default_i128_value(10)
474+
.default_string_value("A string".into())
475+
.default_borrowed_str_value("A borrowed str");
476+
let mut tracer = Tracer::new(config);
477+
let samples = Samples::new();
478+
479+
let (format, value) = tracer
480+
.trace_type_once::<std::num::NonZeroU8>(&samples)
481+
.unwrap();
482+
assert_eq!(format, Format::U8); // Not a container
483+
assert_eq!(value.get(), 1);
484+
485+
let (format, value) = tracer.trace_type_once::<u8>(&samples).unwrap();
486+
assert_eq!(format, Format::U8);
487+
assert_eq!(value, 1);
488+
489+
let (format, value) = tracer.trace_type_once::<u16>(&samples).unwrap();
490+
assert_eq!(format, Format::U16);
491+
assert_eq!(value, 2);
492+
493+
let (format, value) = tracer.trace_type_once::<i128>(&samples).unwrap();
494+
assert_eq!(format, Format::I128);
495+
assert_eq!(value, 10);
496+
497+
let (format, value) = tracer.trace_type_once::<String>(&samples).unwrap();
498+
assert_eq!(format, Format::Str);
499+
assert_eq!(value.as_str(), "A string");
500+
501+
let (format, value) = tracer.trace_type_once::<&str>(&samples).unwrap();
502+
assert_eq!(format, Format::Str);
503+
assert_eq!(value, "A borrowed str");
504+
}

0 commit comments

Comments
 (0)