From 64d8b326da5029c04eadf8c263af38bbf4311dec Mon Sep 17 00:00:00 2001 From: jrb0001 Date: Sat, 26 Apr 2025 14:14:05 +0200 Subject: [PATCH] Add attributes for generating Eq/PartialEq/Ord/PartialOrd/Hash --- bitbybit-tests/src/bitfield_tests.rs | 226 +++++++++++++++++++++++- bitbybit/README.md | 20 ++- bitbybit/examples/simple.rs | 2 +- bitbybit/src/bitfield/codegen.rs | 26 +-- bitbybit/src/bitfield/codegen_traits.rs | 177 +++++++++++++++++++ bitbybit/src/bitfield/mod.rs | 89 +++++++--- 6 files changed, 488 insertions(+), 52 deletions(-) create mode 100644 bitbybit/src/bitfield/codegen_traits.rs diff --git a/bitbybit-tests/src/bitfield_tests.rs b/bitbybit-tests/src/bitfield_tests.rs index 29543aa..cddc51d 100644 --- a/bitbybit-tests/src/bitfield_tests.rs +++ b/bitbybit-tests/src/bitfield_tests.rs @@ -1,9 +1,10 @@ use arbitrary_int::Number; -use std::fmt::Debug; - use arbitrary_int::{u1, u12, u13, u14, u2, u24, u3, u30, u4, u48, u5, u57, u7}; use bitbybit::bitenum; use bitbybit::bitfield; +use std::cmp::Ordering; +use std::fmt::Debug; +use std::hash::{DefaultHasher, Hash, Hasher}; #[test] fn test_construction() { @@ -1678,3 +1679,224 @@ fn test_debug_impl() { let display_str = format!("{:?}", test); assert_eq!(display_str, "Test { upper: 31, lower: 47 }"); } + +#[test] +fn test_eq_impl() { + #[bitfield(u16, eq, partial_eq)] + struct Test { + #[bits(12..=15, rw)] + upper: u4, + + #[bits(0..=3, rw)] + lower: u4, + } + fn assert_impl(_: impl Eq) {} + assert_impl(Test::new_with_raw_value(0x1F2F)); +} + +#[test] +fn test_partial_eq_impl() { + #[bitfield(u16, debug, partial_eq)] + struct Test { + #[bits(12..=15, rw)] + upper: u4, + + #[bits(0..=3, rw)] + lower: u4, + } + let a = Test::new_with_raw_value(0b0101_0101_0101_0101); + let b = Test::new_with_raw_value(0b0101_1010_1010_0101); + let c = Test::new_with_raw_value(0b1010_0101_0101_1010); + assert_eq!(a, b); + assert_ne!(a, c); +} + +#[test] +fn test_ord_impl() { + #[bitfield(u16, default = 0, eq, partial_eq, ord, partial_ord)] + struct Test { + #[bits(0..=7, rw)] + lower: u8, + + #[bits(8..=15, rw)] + upper: u8, + } + + fn create(lower: u8, upper: u8) -> Test { + Test::builder().with_lower(lower).with_upper(upper).build() + } + + let a = create(8, 1); + let b = create(8, 2); + let c = create(8, 3); + assert_eq!(Ordering::Less, b.cmp(&c)); + assert_eq!(Ordering::Equal, b.cmp(&b)); + assert_eq!(Ordering::Greater, b.cmp(&a)); + + let a = create(7, 2); + let b = create(8, 2); + let c = create(9, 2); + assert_eq!(Ordering::Less, b.cmp(&c)); + assert_eq!(Ordering::Equal, b.cmp(&b)); + assert_eq!(Ordering::Greater, b.cmp(&a)); + + let a = create(7, 3); + let b = create(8, 2); + let c = create(9, 1); + assert_eq!(Ordering::Less, b.cmp(&c)); + assert_eq!(Ordering::Equal, b.cmp(&b)); + assert_eq!(Ordering::Greater, b.cmp(&a)); +} + +#[test] +fn test_partial_ord_impl() { + #[bitfield(u8, default = 0)] + struct PartialOrdWrapper { + #[bits(0..=7, rw)] + inner: u8, + } + + impl PartialEq for PartialOrdWrapper { + fn eq(&self, other: &Self) -> bool { + self.partial_cmp(other) == Some(Ordering::Equal) + } + } + impl PartialOrd for PartialOrdWrapper { + fn partial_cmp(&self, other: &Self) -> Option { + if self.inner() == 0 || other.inner() == 0 { + None + } else { + Some(self.inner().cmp(&other.inner())) + } + } + } + + #[bitfield(u16, default = 0, partial_eq, partial_ord)] + struct Test { + #[bits(0..=7, rw)] + lower: PartialOrdWrapper, + + #[bits(8..=15, rw)] + upper: PartialOrdWrapper, + } + + fn create(lower: u8, upper: u8) -> Test { + Test::builder() + .with_lower(PartialOrdWrapper::builder().with_inner(lower).build()) + .with_upper(PartialOrdWrapper::builder().with_inner(upper).build()) + .build() + } + + let a = create(8, 1); + let b = create(8, 2); + let c = create(8, 3); + assert_eq!(Some(Ordering::Less), b.partial_cmp(&c)); + assert_eq!(Some(Ordering::Equal), b.partial_cmp(&b)); + assert_eq!(Some(Ordering::Greater), b.partial_cmp(&a)); + + let a = create(7, 2); + let b = create(8, 2); + let c = create(9, 2); + assert_eq!(Some(Ordering::Less), b.partial_cmp(&c)); + assert_eq!(Some(Ordering::Equal), b.partial_cmp(&b)); + assert_eq!(Some(Ordering::Greater), b.partial_cmp(&a)); + + let a = create(7, 3); + let b = create(8, 2); + let c = create(9, 1); + assert_eq!(Some(Ordering::Less), b.partial_cmp(&c)); + assert_eq!(Some(Ordering::Equal), b.partial_cmp(&b)); + assert_eq!(Some(Ordering::Greater), b.partial_cmp(&a)); + + let a = create(0, 1); + let b = create(0, 2); + let c = create(0, 3); + assert_eq!(None, b.partial_cmp(&c)); + assert_eq!(None, b.partial_cmp(&b)); + assert_eq!(None, b.partial_cmp(&a)); + + let a = create(7, 0); + let b = create(8, 0); + let c = create(9, 0); + assert_eq!(Some(Ordering::Less), b.partial_cmp(&c)); + assert_eq!(None, b.partial_cmp(&b)); + assert_eq!(Some(Ordering::Greater), b.partial_cmp(&a)); +} + +#[test] +fn test_hash_impl() { + #[bitfield(u16, debug, hash)] + struct Test { + #[bits(12..=15, rw)] + upper: u4, + + #[bits(0..=3, rw)] + lower: u4, + } + + fn hash(value: Test) -> u64 { + let mut hasher = DefaultHasher::new(); + value.hash(&mut hasher); + hasher.finish() + } + + let a = Test::new_with_raw_value(0b0101_0101_0101_0101); + let b = Test::new_with_raw_value(0b0101_1010_1010_0101); + let c = Test::new_with_raw_value(0b1010_0101_0101_1010); + assert_eq!(hash(a), hash(b)); + assert_ne!(hash(a), hash(c)); +} + +#[test] +fn test_eq_ord_consistency_with_enum_nonexhaustive() { + #[bitenum(u2, exhaustive = false)] + #[derive(Eq, PartialEq, Ord, PartialOrd, Debug)] + pub enum NonExhaustiveEnum { + One = 0b01, + Two = 0b10, + } + + #[bitfield(u64, default = 0, debug, eq, partial_eq, ord, partial_ord)] + pub struct BitfieldWithEnumNonExhaustive { + #[bits(2..=3, rw)] + e2: Option, + } + + fn assert( + expected: Ordering, + a: BitfieldWithEnumNonExhaustive, + b: BitfieldWithEnumNonExhaustive, + ) { + assert_eq!(expected, a.cmp(&b), "{:?}.cmp({:?})", a, b); + assert_eq!(Some(expected), a.partial_cmp(&b), "{:?}.partial_cmp({:?})", a, b); + + if expected == Ordering::Equal { + assert_eq!(a, b, "{:?}.eq({:?})", a, b); + } else { + assert_ne!(a, b, "{:?}.ne({:?})", a, b); + } + } + + let zero = BitfieldWithEnumNonExhaustive::new_with_raw_value(0b0010); + let one = BitfieldWithEnumNonExhaustive::new_with_raw_value(0b0110); + let two = BitfieldWithEnumNonExhaustive::new_with_raw_value(0b1010); + let three = BitfieldWithEnumNonExhaustive::new_with_raw_value(0b1110); + + // From Result.cmp(): Ok(One) < Ok(Two) < Err(0) < Err(3) + assert(Ordering::Equal, zero, zero); + assert(Ordering::Greater, zero, one); + assert(Ordering::Greater, zero, two); + assert(Ordering::Less, zero, three); + assert(Ordering::Less, one, zero); + assert(Ordering::Equal, one, one); + assert(Ordering::Less, one, two); + assert(Ordering::Less, one, three); + assert(Ordering::Less, two, zero); + assert(Ordering::Greater, two, one); + assert(Ordering::Equal, two, two); + assert(Ordering::Less, two, three); + assert(Ordering::Greater, three, zero); + assert(Ordering::Greater, three, one); + assert(Ordering::Greater, three, two); + assert(Ordering::Equal, three, three); +} diff --git a/bitbybit/README.md b/bitbybit/README.md index a616537..2b4b3c1 100644 --- a/bitbybit/README.md +++ b/bitbybit/README.md @@ -187,14 +187,24 @@ immediates in a way that they have to be reassembled. This can be achieved like } ``` -## Debug +## Debug / Eq / PartialEq / Ord / PartialOrd / Hash -The `bitfield` macro can generate a `Debug` implementation for you which prints -the `Debug` implementation of the inner fields. You can do this using the `debug` -specifier: +The `bitfield` macro can generate implementations of a few core traits for you. +You can do this using the specifiers from the following table: + +| Trait | Specifier | Description | +|--------------|---------------|--------------------------------------------------------------------------| +| `Debug` | `debug` | Prints the `Debug` implementation of all fields. | +| `Eq` | `eq` | | +| `PartialEq` | `partial_eq` | Unused bits are ignored. | +| `Ord` | `ord` | Uses the `Ord` implementation of all fields in declaration order. | +| `PartialOrd` | `partial_ord` | Uses the `PartialOrd` implementation of all fields in declaration order. | +| `Hash` | `hash` | Unused bits are ignored. | + +Example: ```rs -#[bitfield(u32, debug)] +#[bitfield(u32, debug, eq, partial_eq, ord, partial_ord, hash)] struct GICD_TYPER { #[bits(11..=15, r)] lspi: u5, diff --git a/bitbybit/examples/simple.rs b/bitbybit/examples/simple.rs index e9201c0..a1f458e 100644 --- a/bitbybit/examples/simple.rs +++ b/bitbybit/examples/simple.rs @@ -1,7 +1,7 @@ use arbitrary_int::u4; use bitbybit::bitfield; -#[bitfield(u32, debug)] +#[bitfield(u32, debug, eq, partial_eq, ord, partial_ord, hash)] pub struct BitfieldU32 { #[bits(28..=31, rw)] val3: u4, diff --git a/bitbybit/src/bitfield/codegen.rs b/bitbybit/src/bitfield/codegen.rs index f119e85..83e1236 100644 --- a/bitbybit/src/bitfield/codegen.rs +++ b/bitbybit/src/bitfield/codegen.rs @@ -389,13 +389,9 @@ pub fn make_builder( if ranges_have_self_overlap(&field_definition.ranges, array_stride, array_count) { return (quote! {}, Vec::new()); } - let mut mask = 0; + let mask = super::used_bits_mask_for_field(field_definition); let mut array_setters = Vec::with_capacity(array_count); for i in 0..array_count { - mask |= field_definition.ranges.iter().fold(0u128, |a, range| { - a | (((1u128 << range.len()) - 1) << (range.start + i * array_stride)) - }); - array_setters.push(quote! { .#with_name(#i, value[#i]) }); } let value_transform = quote!(self.0 #( #array_setters )*); @@ -403,24 +399,12 @@ pub fn make_builder( (mask, value_transform, array_type) } else { - let mask = if field_definition.ranges.len() == 1 { - if field_definition.ranges[0].len() == 128 { - u128::MAX - } else { - ((1u128 << field_definition.ranges[0].len()) - 1) - << (field_definition.ranges[0].start) - } - } else { - if ranges_have_self_overlap(&field_definition.ranges, 0, 0) { - return (quote! {}, Vec::new()); - } - field_definition.ranges.iter().fold(0u128, |a, range| { - a | (((1u128 << range.len()) - 1) << (range.start)) - }) - }; + if ranges_have_self_overlap(&field_definition.ranges, 0, 0) { + return (quote! {}, Vec::new()); + } ( - mask, + super::used_bits_mask_for_field(field_definition), quote! { self.0.#with_name(value)}, quote! { #setter_type }, ) diff --git a/bitbybit/src/bitfield/codegen_traits.rs b/bitbybit/src/bitfield/codegen_traits.rs new file mode 100644 index 0000000..b51b48c --- /dev/null +++ b/bitbybit/src/bitfield/codegen_traits.rs @@ -0,0 +1,177 @@ +use crate::bitfield::{BitfieldAttributes, FieldDefinition}; +use proc_macro2::{Ident, TokenStream, TokenTree}; +use quote::quote; + +pub fn generate( + struct_name: &Ident, + bitfield_attrs: &BitfieldAttributes, + field_definitions: &[FieldDefinition], +) -> impl Iterator { + let mask = syn::parse_str::( + format!( + "{:#x}", + super::used_bits_mask_for_struct(&field_definitions) + ) + .as_str(), + ) + .unwrap(); + + let debug_trait = generate_debug_trait(struct_name, bitfield_attrs, field_definitions); + let eq_trait = generate_eq_trait(struct_name, bitfield_attrs); + let partial_eq_trait = generate_partial_eq_trait(struct_name, bitfield_attrs, &mask); + let ord_trait = generate_ord_trait(struct_name, bitfield_attrs, field_definitions); + let partial_ord_trait = + generate_partial_ord_trait(struct_name, bitfield_attrs, field_definitions); + let hash_trait = generate_hash_trait(struct_name, bitfield_attrs, mask); + + debug_trait + .into_iter() + .chain(eq_trait.into_iter()) + .chain(partial_eq_trait.into_iter()) + .chain(ord_trait.into_iter()) + .chain(partial_ord_trait.into_iter()) + .chain(hash_trait.into_iter()) +} + +fn generate_debug_trait( + struct_name: &Ident, + bitfield_attrs: &BitfieldAttributes, + field_definitions: &[FieldDefinition], +) -> Option { + if bitfield_attrs.debug_trait { + let debug_fields: Vec = field_definitions + .iter() + .map(|field| { + let field_name = &field.field_name; + quote! { + .field(stringify!(#field_name), &self.#field_name()) + } + }) + .collect(); + Some(quote! { + impl ::core::fmt::Debug for #struct_name { + fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { + f.debug_struct(stringify!(#struct_name)) + #(#debug_fields)* + .finish() + } + } + }) + } else { + None + } +} + +fn generate_eq_trait( + struct_name: &Ident, + bitfield_attrs: &BitfieldAttributes, +) -> Option { + if bitfield_attrs.eq_trait { + Some(quote! { + impl ::core::cmp::Eq for #struct_name {} + }) + } else { + None + } +} + +fn generate_partial_eq_trait( + struct_name: &Ident, + bitfield_attrs: &BitfieldAttributes, + mask: &TokenTree, +) -> Option { + if bitfield_attrs.partial_eq_trait { + Some(quote! { + impl ::core::cmp::PartialEq for #struct_name { + #[inline] + fn eq(&self, other: &Self) -> bool { + self.raw_value & #mask == other.raw_value & #mask + } + } + }) + } else { + None + } +} + +fn generate_ord_trait( + struct_name: &Ident, + bitfield_attrs: &BitfieldAttributes, + field_definitions: &[FieldDefinition], +) -> Option { + if bitfield_attrs.ord_trait { + let ord_fields: Vec = field_definitions + .iter() + .map(|field| { + let field_name = &field.field_name; + quote! { + match self.#field_name().cmp(&other.#field_name()) { + core::cmp::Ordering::Equal => {}, + cmp => return cmp, + } + } + }) + .collect(); + Some(quote! { + impl ::core::cmp::Ord for #struct_name { + fn cmp(&self, other: &Self) -> ::core::cmp::Ordering { + #(#ord_fields)* + ::core::cmp::Ordering::Equal + } + } + }) + } else { + None + } +} + +fn generate_partial_ord_trait( + struct_name: &Ident, + bitfield_attrs: &BitfieldAttributes, + field_definitions: &[FieldDefinition], +) -> Option { + if bitfield_attrs.partial_ord_trait { + let partial_ord_fields: Vec = field_definitions + .iter() + .map(|field| { + let field_name = &field.field_name; + quote! { + match self.#field_name().partial_cmp(&other.#field_name()) { + ::core::option::Option::Some(core::cmp::Ordering::Equal) => {}, + ::core::option::Option::Some(cmp) => return ::core::option::Option::Some(cmp), + ::core::option::Option::None => return ::core::option::Option::None, + } + } + }) + .collect(); + Some(quote! { + impl ::core::cmp::PartialOrd for #struct_name { + fn partial_cmp(&self, other: &Self) -> ::core::option::Option<::core::cmp::Ordering> { + #(#partial_ord_fields)* + ::core::option::Option::Some(::core::cmp::Ordering::Equal) + } + } + }) + } else { + None + } +} + +fn generate_hash_trait( + struct_name: &Ident, + bitfield_attrs: &BitfieldAttributes, + mask: TokenTree, +) -> Option { + if bitfield_attrs.hash_trait { + Some(quote! { + impl ::core::hash::Hash for #struct_name { + #[inline] + fn hash(&self, state: &mut H) { + (self.raw_value & #mask).hash(state) + } + } + }) + } else { + None + } +} diff --git a/bitbybit/src/bitfield/mod.rs b/bitbybit/src/bitfield/mod.rs index 386314a..1c768c5 100644 --- a/bitbybit/src/bitfield/mod.rs +++ b/bitbybit/src/bitfield/mod.rs @@ -1,11 +1,11 @@ mod codegen; +mod codegen_traits; mod parsing; use proc_macro::Span; use proc_macro::TokenStream; use proc_macro2::Ident; use proc_macro2::TokenStream as TokenStream2; -use quote::TokenStreamExt; use quote::{quote, ToTokens}; use std::ops::Range; use std::str::FromStr; @@ -95,6 +95,11 @@ struct BitfieldAttributes { pub base_type: Option, pub default_val: Option, pub debug_trait: bool, + pub eq_trait: bool, + pub partial_eq_trait: bool, + pub ord_trait: bool, + pub partial_ord_trait: bool, + pub hash_trait: bool, } impl BitfieldAttributes { @@ -129,6 +134,26 @@ impl BitfieldAttributes { self.debug_trait = true; return Ok(()); } + if meta.path.is_ident("eq") { + self.eq_trait = true; + return Ok(()); + } + if meta.path.is_ident("partial_eq") { + self.partial_eq_trait = true; + return Ok(()); + } + if meta.path.is_ident("ord") { + self.ord_trait = true; + return Ok(()); + } + if meta.path.is_ident("partial_ord") { + self.partial_ord_trait = true; + return Ok(()); + } + if meta.path.is_ident("hash") { + self.hash_trait = true; + return Ok(()); + } Ok(()) } } @@ -234,27 +259,6 @@ pub fn bitfield(args: TokenStream, input: TokenStream) -> TokenStream { (quote! {}, quote! {}) }; - let mut debug_trait = TokenStream2::new(); - if bitfield_attrs.debug_trait { - let debug_fields: Vec = field_definitions - .iter() - .map(|field| { - let field_name = &field.field_name; - quote! { - .field(stringify!(#field_name), &self.#field_name()) - } - }) - .collect(); - debug_trait.append_all(quote! { - impl ::core::fmt::Debug for #struct_name { - fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { - f.debug_struct(stringify!(#struct_name)) - #(#debug_fields)* - .finish() - } - } - }); - } let (new_with_constructor, new_with_builder_chain) = codegen::make_builder( &struct_name, bitfield_attrs.default_val.is_some(), @@ -293,6 +297,8 @@ pub fn bitfield(args: TokenStream, input: TokenStream) -> TokenStream { zero ); + let traits = codegen_traits::generate(&struct_name, &bitfield_attrs, &field_definitions); + let expanded = quote! { #[derive(Copy, Clone)] #[repr(C)] @@ -322,7 +328,7 @@ pub fn bitfield(args: TokenStream, input: TokenStream) -> TokenStream { #( #accessors )* } #default_trait - #debug_trait + #( #traits )* #( #new_with_builder_chain )* }; //println!("Expanded: {}", expanded.to_string()); @@ -359,6 +365,43 @@ fn setter_name(field_name: &Ident) -> Ident { .unwrap_or_else(|_| panic!("bitfield!: Error creating setter name")) } +fn used_bits_mask_for_struct(field_definitions: &[FieldDefinition]) -> u128 { + field_definitions + .iter() + .map(used_bits_mask_for_field) + .fold(0u128, |a, b| a | b) +} + +fn used_bits_mask_for_field(field_definition: &FieldDefinition) -> u128 { + if let Some(array) = field_definition.array { + let array_count = array.0; + let array_stride = array.1; + let mut mask = 0; + for i in 0..array_count { + mask |= field_definition.ranges.iter().fold(0u128, |a, range| { + a | (((1u128 << range.len()) - 1) << (range.start + i * array_stride)) + }); + } + + mask + } else { + let mask = if field_definition.ranges.len() == 1 { + if field_definition.ranges[0].len() == 128 { + u128::MAX + } else { + ((1u128 << field_definition.ranges[0].len()) - 1) + << (field_definition.ranges[0].start) + } + } else { + field_definition.ranges.iter().fold(0u128, |a, range| { + a | (((1u128 << range.len()) - 1) << (range.start)) + }) + }; + + mask + } +} + struct FieldDefinition { field_name: Ident, ranges: Vec>,