From 2e8c4c0be75a63c67bd6e576b97250df0a6b67eb Mon Sep 17 00:00:00 2001 From: Casper Meijn Date: Thu, 29 Aug 2024 11:10:01 +0200 Subject: [PATCH] feat: derive Eq and Hash trait for messages where possible Integer and bytes types can be compared using trait Eq. Some generated Rust structs can also have this property by deriving the Eq trait. Automatically derive Eq and Hash for: - messages that only have fields with integer or bytes types - messages where all field types also implement Eq and Hash - the Rust enum for one-of fields, where all fields implement Eq and Hash Generated code for Protobuf enums already derives Eq and Hash. BREAKING CHANGE: `prost-build` will automatically derive `trait Eq` and `trait Hash` for types where all field support those as well. If you manually `impl Eq` and/or `impl Hash` for generated types, then you need to remove the manual implementation. If you use `type_attribute` to `derive(Eq)` and/or `derive(Hash)`, then you need to remove those. --- prost-build/src/code_generator.rs | 18 +++++++- prost-build/src/context.rs | 45 +++++++++++++++++++ .../_expected_field_attributes.rs | 10 ++--- .../_expected_field_attributes_formatted.rs | 10 ++--- .../helloworld/_expected_helloworld.rs | 4 +- .../_expected_helloworld_formatted.rs | 4 +- prost-types/src/compiler.rs | 2 +- prost-types/src/duration.rs | 8 ---- prost-types/src/protobuf.rs | 24 +++++----- prost-types/src/timestamp.rs | 13 ------ tests/build.rs | 2 +- tests/single-include/src/outdir/outdir.rs | 2 +- 12 files changed, 90 insertions(+), 52 deletions(-) diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 15048d54a..ee24e10ce 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -223,12 +223,17 @@ impl<'b> CodeGenerator<'_, 'b> { self.append_message_attributes(&fq_message_name); self.push_indent(); self.buf.push_str(&format!( - "#[derive(Clone, {}PartialEq, {}::Message)]\n", + "#[derive(Clone, {}PartialEq, {}{}::Message)]\n", if self.context.can_message_derive_copy(&fq_message_name) { "Copy, " } else { "" }, + if self.context.can_message_derive_eq(&fq_message_name) { + "Eq, Hash, " + } else { + "" + }, self.context.prost_path() )); self.append_skip_debug(&fq_message_name); @@ -596,9 +601,18 @@ impl<'b> CodeGenerator<'_, 'b> { self.context .can_field_derive_copy(fq_message_name, &field.descriptor) }); + let can_oneof_derive_eq = oneof.fields.iter().all(|field| { + self.context + .can_field_derive_eq(fq_message_name, &field.descriptor) + }); self.buf.push_str(&format!( - "#[derive(Clone, {}PartialEq, {}::Oneof)]\n", + "#[derive(Clone, {}PartialEq, {}{}::Oneof)]\n", if can_oneof_derive_copy { "Copy, " } else { "" }, + if can_oneof_derive_eq { + "Eq, Hash, " + } else { + "" + }, self.context.prost_path() )); self.append_skip_debug(fq_message_name); diff --git a/prost-build/src/context.rs b/prost-build/src/context.rs index 89b24b9d8..6a4d2bcaa 100644 --- a/prost-build/src/context.rs +++ b/prost-build/src/context.rs @@ -234,6 +234,51 @@ impl<'a> Context<'a> { } } + /// Returns `true` if this message can automatically derive Eq trait. + pub fn can_message_derive_eq(&self, fq_message_name: &str) -> bool { + assert_eq!(".", &fq_message_name[..1]); + + let msg = self.message_graph.get_message(fq_message_name).unwrap(); + msg.field + .iter() + .all(|field| self.can_field_derive_eq(fq_message_name, field)) + } + + /// Returns `true` if the type of this field allows deriving the Eq trait. + pub fn can_field_derive_eq(&self, fq_message_name: &str, field: &FieldDescriptorProto) -> bool { + assert_eq!(".", &fq_message_name[..1]); + + if field.r#type() == Type::Message { + if field.label() == Label::Repeated + || self + .message_graph + .is_nested(field.type_name(), fq_message_name) + { + false + } else { + self.can_message_derive_eq(field.type_name()) + } + } else { + matches!( + field.r#type(), + Type::Int32 + | Type::Int64 + | Type::Uint32 + | Type::Uint64 + | Type::Sint32 + | Type::Sint64 + | Type::Fixed32 + | Type::Fixed64 + | Type::Sfixed32 + | Type::Sfixed64 + | Type::Bool + | Type::Enum + | Type::String + | Type::Bytes + ) + } + } + pub fn should_disable_comments(&self, fq_message_name: &str, field_name: Option<&str>) -> bool { if let Some(field_name) = field_name { self.config diff --git a/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs b/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs index bf1e8c517..509e96bbe 100644 --- a/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs +++ b/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs @@ -1,12 +1,12 @@ // This file is @generated by prost-build. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Container { #[prost(oneof="container::Data", tags="1, 2")] pub data: ::core::option::Option, } /// Nested message and enum types in `Container`. pub mod container { - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum Data { #[prost(message, tag="1")] Foo(::prost::alloc::boxed::Box), @@ -14,16 +14,16 @@ pub mod container { Bar(super::Bar), } } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Foo { #[prost(string, tag="1")] pub foo: ::prost::alloc::string::String, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Bar { #[prost(message, optional, boxed, tag="1")] pub qux: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct Qux { } diff --git a/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs b/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs index c130aad2e..9f5b10cb1 100644 --- a/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs +++ b/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs @@ -1,12 +1,12 @@ // This file is @generated by prost-build. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Container { #[prost(oneof = "container::Data", tags = "1, 2")] pub data: ::core::option::Option, } /// Nested message and enum types in `Container`. pub mod container { - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum Data { #[prost(message, tag = "1")] Foo(::prost::alloc::boxed::Box), @@ -14,15 +14,15 @@ pub mod container { Bar(super::Bar), } } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Foo { #[prost(string, tag = "1")] pub foo: ::prost::alloc::string::String, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Bar { #[prost(message, optional, boxed, tag = "1")] pub qux: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct Qux {} diff --git a/prost-build/src/fixtures/helloworld/_expected_helloworld.rs b/prost-build/src/fixtures/helloworld/_expected_helloworld.rs index f39278358..ae65e24df 100644 --- a/prost-build/src/fixtures/helloworld/_expected_helloworld.rs +++ b/prost-build/src/fixtures/helloworld/_expected_helloworld.rs @@ -1,14 +1,14 @@ // This file is @generated by prost-build. #[derive(derive_builder::Builder)] #[derive(custom_proto::Input)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Message { #[prost(string, tag="1")] pub say: ::prost::alloc::string::String, } #[derive(derive_builder::Builder)] #[derive(custom_proto::Output)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Response { #[prost(string, tag="1")] pub say: ::prost::alloc::string::String, diff --git a/prost-build/src/fixtures/helloworld/_expected_helloworld_formatted.rs b/prost-build/src/fixtures/helloworld/_expected_helloworld_formatted.rs index c75338e2b..49a1f1f9e 100644 --- a/prost-build/src/fixtures/helloworld/_expected_helloworld_formatted.rs +++ b/prost-build/src/fixtures/helloworld/_expected_helloworld_formatted.rs @@ -1,14 +1,14 @@ // This file is @generated by prost-build. #[derive(derive_builder::Builder)] #[derive(custom_proto::Input)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Message { #[prost(string, tag = "1")] pub say: ::prost::alloc::string::String, } #[derive(derive_builder::Builder)] #[derive(custom_proto::Output)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Response { #[prost(string, tag = "1")] pub say: ::prost::alloc::string::String, diff --git a/prost-types/src/compiler.rs b/prost-types/src/compiler.rs index d274aeba8..89e2e253f 100644 --- a/prost-types/src/compiler.rs +++ b/prost-types/src/compiler.rs @@ -1,7 +1,7 @@ // This file is @generated by prost-build. /// The version number of protocol compiler. #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Version { #[prost(int32, optional, tag = "1")] pub major: ::core::option::Option, diff --git a/prost-types/src/duration.rs b/prost-types/src/duration.rs index 3ce993ee5..187f04a43 100644 --- a/prost-types/src/duration.rs +++ b/prost-types/src/duration.rs @@ -1,13 +1,5 @@ use super::*; -#[cfg(feature = "std")] -impl std::hash::Hash for Duration { - fn hash(&self, state: &mut H) { - self.seconds.hash(state); - self.nanos.hash(state); - } -} - impl Duration { /// Normalizes the duration to a canonical format. /// diff --git a/prost-types/src/protobuf.rs b/prost-types/src/protobuf.rs index b24264901..53e1d5df1 100644 --- a/prost-types/src/protobuf.rs +++ b/prost-types/src/protobuf.rs @@ -94,7 +94,7 @@ pub mod descriptor_proto { /// fields or extension ranges in the same message. Reserved ranges may /// not overlap. #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] - #[derive(Clone, Copy, PartialEq, ::prost::Message)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct ReservedRange { /// Inclusive. #[prost(int32, optional, tag = "1")] @@ -362,7 +362,7 @@ pub mod enum_descriptor_proto { /// is inclusive such that it can appropriately represent the entire int32 /// domain. #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] - #[derive(Clone, Copy, PartialEq, ::prost::Message)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct EnumReservedRange { /// Inclusive. #[prost(int32, optional, tag = "1")] @@ -990,7 +990,7 @@ pub mod uninterpreted_option { /// E.g.,{ \["foo", false\], \["bar.baz", true\], \["qux", false\] } represents /// "foo.(bar.baz).qux". #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] - #[derive(Clone, PartialEq, ::prost::Message)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct NamePart { #[prost(string, required, tag = "1")] pub name_part: ::prost::alloc::string::String, @@ -1053,7 +1053,7 @@ pub struct SourceCodeInfo { /// Nested message and enum types in `SourceCodeInfo`. pub mod source_code_info { #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] - #[derive(Clone, PartialEq, ::prost::Message)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Location { /// Identifies which part of the FileDescriptorProto was defined at this /// location. @@ -1158,7 +1158,7 @@ pub struct GeneratedCodeInfo { /// Nested message and enum types in `GeneratedCodeInfo`. pub mod generated_code_info { #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] - #[derive(Clone, PartialEq, ::prost::Message)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Annotation { /// Identifies the element in the original source .proto file. This field /// is formatted the same as SourceCodeInfo.Location.path. @@ -1272,7 +1272,7 @@ pub mod generated_code_info { /// } /// ``` #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Any { /// A URL/resource name that uniquely identifies the type of the serialized /// protocol buffer message. This string must contain at least @@ -1310,7 +1310,7 @@ pub struct Any { /// `SourceContext` represents information about the source of a /// protobuf element, like the file in which it is defined. #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct SourceContext { /// The path-qualified name of the .proto file that contained the associated /// protobuf element. For example: `"google/protobuf/source_context.proto"`. @@ -1573,7 +1573,7 @@ pub struct EnumValue { /// A protocol buffer option, which can be attached to a message, field, /// enumeration, etc. #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Option { /// The option's name. For protobuf built-in options (options defined in /// descriptor.proto), this is the short name. For example, `"map_entry"`. @@ -1787,7 +1787,7 @@ pub struct Method { /// } /// ``` #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Mixin { /// The fully qualified name of the interface which is included. #[prost(string, tag = "1")] @@ -1862,7 +1862,7 @@ pub struct Mixin { /// be expressed in JSON format as "3.000000001s", and 3 seconds and 1 /// microsecond should be expressed in JSON format as "3.000001s". #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct Duration { /// Signed seconds of the span of time. Must be from -315,576,000,000 /// to +315,576,000,000 inclusive. Note: these bounds are computed from: @@ -2101,7 +2101,7 @@ pub struct Duration { /// request should verify the included field paths, and return an /// `INVALID_ARGUMENT` error if any path is unmappable. #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct FieldMask { /// The set of field mask paths. #[prost(string, repeated, tag = "1")] @@ -2303,7 +2303,7 @@ impl NullValue { /// the time format spec '%Y-%m-%dT%H:%M:%S.%fZ'. Likewise, in Java, one can use /// the Joda Time's [`ISODateTimeFormat.dateTime()`]() to obtain a formatter capable of generating timestamps in this format. #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct Timestamp { /// Represents seconds of UTC time since Unix epoch /// 1970-01-01T00:00:00Z. Must be from 0001-01-01T00:00:00Z to diff --git a/prost-types/src/timestamp.rs b/prost-types/src/timestamp.rs index 9ed0db6ea..69be3763a 100644 --- a/prost-types/src/timestamp.rs +++ b/prost-types/src/timestamp.rs @@ -123,19 +123,6 @@ impl Name for Timestamp { } } -/// Implements the unstable/naive version of `Eq`: a basic equality check on the internal fields of the `Timestamp`. -/// This implies that `normalized_ts != non_normalized_ts` even if `normalized_ts == non_normalized_ts.normalized()`. -#[cfg(feature = "std")] -impl Eq for Timestamp {} - -#[cfg(feature = "std")] -impl std::hash::Hash for Timestamp { - fn hash(&self, state: &mut H) { - self.seconds.hash(state); - self.nanos.hash(state); - } -} - #[cfg(feature = "std")] impl From for Timestamp { fn from(system_time: std::time::SystemTime) -> Timestamp { diff --git a/tests/build.rs b/tests/build.rs index 446c8cf0e..4bd82f4e8 100644 --- a/tests/build.rs +++ b/tests/build.rs @@ -34,7 +34,7 @@ fn main() { ); config.type_attribute( "Foo.Custom.OneOfAttrs.Msg.field", - "#[derive(Eq, PartialOrd, Ord)]", + "#[derive(PartialOrd, Ord)]", ); config.file_descriptor_set_path( diff --git a/tests/single-include/src/outdir/outdir.rs b/tests/single-include/src/outdir/outdir.rs index 233028a04..c285a3875 100644 --- a/tests/single-include/src/outdir/outdir.rs +++ b/tests/single-include/src/outdir/outdir.rs @@ -1,5 +1,5 @@ // This file is @generated by prost-build. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct OutdirRequest { #[prost(string, tag = "1")] pub query: ::prost::alloc::string::String,