Skip to content

Commit 85c698a

Browse files
authored
feat: derive Copy trait for messages where possible (#950)
* feat: derive Copy trait for messages where possible Rust primitive types can be copied by simply copying the bits. Rust structs can also have this property by deriving the Copy trait. Automatically derive Copy for: - messages that only have fields with primitive types - the Rust enum for one-of fields - messages whose field type are messages that also implement Copy Generated code for Protobuf enums already derives Copy. * fix: Remove clone call when copy is implemented Clippy reports: warning: using `clone` on type `Timestamp` which implements the `Copy` trait
1 parent d42c85e commit 85c698a

File tree

12 files changed

+163
-24
lines changed

12 files changed

+163
-24
lines changed

prost-build/src/code_generator.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,12 @@ impl<'a> CodeGenerator<'a> {
231231
self.buf
232232
.push_str("#[allow(clippy::derive_partial_eq_without_eq)]\n");
233233
self.buf.push_str(&format!(
234-
"#[derive(Clone, PartialEq, {}::Message)]\n",
234+
"#[derive(Clone, {}PartialEq, {}::Message)]\n",
235+
if self.message_graph.can_message_derive_copy(&fq_message_name) {
236+
"Copy, "
237+
} else {
238+
""
239+
},
235240
prost_path(self.config)
236241
));
237242
self.append_skip_debug(&fq_message_name);
@@ -613,8 +618,14 @@ impl<'a> CodeGenerator<'a> {
613618
self.push_indent();
614619
self.buf
615620
.push_str("#[allow(clippy::derive_partial_eq_without_eq)]\n");
621+
622+
let can_oneof_derive_copy = fields.iter().map(|(field, _idx)| field).all(|field| {
623+
self.message_graph
624+
.can_field_derive_copy(fq_message_name, field)
625+
});
616626
self.buf.push_str(&format!(
617-
"#[derive(Clone, PartialEq, {}::Oneof)]\n",
627+
"#[derive(Clone, {}PartialEq, {}::Oneof)]\n",
628+
if can_oneof_derive_copy { "Copy, " } else { "" },
618629
prost_path(self.config)
619630
));
620631
self.append_skip_debug(fq_message_name);

prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ pub struct Foo {
2323
pub foo: ::prost::alloc::string::String,
2424
}
2525
#[allow(clippy::derive_partial_eq_without_eq)]
26-
#[derive(Clone, PartialEq, ::prost::Message)]
26+
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
2727
pub struct Bar {
2828
#[prost(message, optional, boxed, tag="1")]
2929
pub qux: ::core::option::Option<::prost::alloc::boxed::Box<Qux>>,
3030
}
3131
#[allow(clippy::derive_partial_eq_without_eq)]
32-
#[derive(Clone, PartialEq, ::prost::Message)]
32+
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
3333
pub struct Qux {
3434
}

prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ pub struct Foo {
2323
pub foo: ::prost::alloc::string::String,
2424
}
2525
#[allow(clippy::derive_partial_eq_without_eq)]
26-
#[derive(Clone, PartialEq, ::prost::Message)]
26+
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
2727
pub struct Bar {
2828
#[prost(message, optional, boxed, tag = "1")]
2929
pub qux: ::core::option::Option<::prost::alloc::boxed::Box<Qux>>,
3030
}
3131
#[allow(clippy::derive_partial_eq_without_eq)]
32-
#[derive(Clone, PartialEq, ::prost::Message)]
32+
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
3333
pub struct Qux {}

prost-build/src/message_graph.rs

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,18 @@ use petgraph::algo::has_path_connecting;
44
use petgraph::graph::NodeIndex;
55
use petgraph::Graph;
66

7-
use prost_types::{field_descriptor_proto, DescriptorProto, FileDescriptorProto};
7+
use prost_types::{
8+
field_descriptor_proto::{Label, Type},
9+
DescriptorProto, FieldDescriptorProto, FileDescriptorProto,
10+
};
811

912
/// `MessageGraph` builds a graph of messages whose edges correspond to nesting.
1013
/// The goal is to recognize when message types are recursively nested, so
1114
/// that fields can be boxed when necessary.
1215
pub struct MessageGraph {
1316
index: HashMap<String, NodeIndex>,
1417
graph: Graph<String, ()>,
18+
messages: HashMap<String, DescriptorProto>,
1519
}
1620

1721
impl MessageGraph {
@@ -21,6 +25,7 @@ impl MessageGraph {
2125
let mut msg_graph = MessageGraph {
2226
index: HashMap::new(),
2327
graph: Graph::new(),
28+
messages: HashMap::new(),
2429
};
2530

2631
for file in files {
@@ -41,6 +46,7 @@ impl MessageGraph {
4146
let MessageGraph {
4247
ref mut index,
4348
ref mut graph,
49+
..
4450
} = *self;
4551
assert_eq!(b'.', msg_name.as_bytes()[0]);
4652
*index
@@ -58,13 +64,12 @@ impl MessageGraph {
5864
let msg_index = self.get_or_insert_index(msg_name.clone());
5965

6066
for field in &msg.field {
61-
if field.r#type() == field_descriptor_proto::Type::Message
62-
&& field.label() != field_descriptor_proto::Label::Repeated
63-
{
67+
if field.r#type() == Type::Message && field.label() != Label::Repeated {
6468
let field_index = self.get_or_insert_index(field.type_name.clone().unwrap());
6569
self.graph.add_edge(msg_index, field_index, ());
6670
}
6771
}
72+
self.messages.insert(msg_name.clone(), msg.clone());
6873

6974
for msg in &msg.nested_type {
7075
self.add_message(&msg_name, msg);
@@ -84,4 +89,50 @@ impl MessageGraph {
8489

8590
has_path_connecting(&self.graph, outer, inner, None)
8691
}
92+
93+
/// Returns `true` if this message can automatically derive Copy trait.
94+
pub fn can_message_derive_copy(&self, fq_message_name: &str) -> bool {
95+
assert_eq!(".", &fq_message_name[..1]);
96+
let msg = self.messages.get(fq_message_name).unwrap();
97+
msg.field
98+
.iter()
99+
.all(|field| self.can_field_derive_copy(fq_message_name, field))
100+
}
101+
102+
/// Returns `true` if the type of this field allows deriving the Copy trait.
103+
pub fn can_field_derive_copy(
104+
&self,
105+
fq_message_name: &str,
106+
field: &FieldDescriptorProto,
107+
) -> bool {
108+
assert_eq!(".", &fq_message_name[..1]);
109+
110+
if field.label() == Label::Repeated {
111+
false
112+
} else if field.r#type() == Type::Message {
113+
if self.is_nested(field.type_name(), fq_message_name) {
114+
false
115+
} else {
116+
self.can_message_derive_copy(field.type_name())
117+
}
118+
} else {
119+
matches!(
120+
field.r#type(),
121+
Type::Float
122+
| Type::Double
123+
| Type::Int32
124+
| Type::Int64
125+
| Type::Uint32
126+
| Type::Uint64
127+
| Type::Sint32
128+
| Type::Sint64
129+
| Type::Fixed32
130+
| Type::Fixed64
131+
| Type::Sfixed32
132+
| Type::Sfixed64
133+
| Type::Bool
134+
| Type::Enum
135+
)
136+
}
137+
}
87138
}

prost-types/src/datetime.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,7 @@ mod tests {
614614
};
615615
assert_eq!(
616616
expected,
617-
format!("{}", DateTime::from(timestamp.clone())),
617+
format!("{}", DateTime::from(timestamp)),
618618
"timestamp: {:?}",
619619
timestamp
620620
);

prost-types/src/duration.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ impl TryFrom<Duration> for time::Duration {
105105

106106
impl fmt::Display for Duration {
107107
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108-
let mut d = self.clone();
108+
let mut d = *self;
109109
d.normalize();
110110
if self.seconds < 0 && self.nanos < 0 {
111111
write!(f, "-")?;
@@ -193,7 +193,7 @@ mod tests {
193193
Ok(duration) => duration,
194194
Err(_) => return Err(TestCaseError::reject("duration out of range")),
195195
};
196-
prop_assert_eq!(time::Duration::try_from(prost_duration.clone()).unwrap(), std_duration);
196+
prop_assert_eq!(time::Duration::try_from(prost_duration).unwrap(), std_duration);
197197

198198
if std_duration != time::Duration::default() {
199199
let neg_prost_duration = Duration {
@@ -220,7 +220,7 @@ mod tests {
220220
Ok(duration) => duration,
221221
Err(_) => return Err(TestCaseError::reject("duration out of range")),
222222
};
223-
prop_assert_eq!(time::Duration::try_from(prost_duration.clone()).unwrap(), std_duration);
223+
prop_assert_eq!(time::Duration::try_from(prost_duration).unwrap(), std_duration);
224224

225225
if std_duration != time::Duration::default() {
226226
let neg_prost_duration = Duration {

prost-types/src/protobuf.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ pub mod descriptor_proto {
9494
/// fields or extension ranges in the same message. Reserved ranges may
9595
/// not overlap.
9696
#[allow(clippy::derive_partial_eq_without_eq)]
97-
#[derive(Clone, PartialEq, ::prost::Message)]
97+
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
9898
pub struct ReservedRange {
9999
/// Inclusive.
100100
#[prost(int32, optional, tag = "1")]
@@ -360,7 +360,7 @@ pub mod enum_descriptor_proto {
360360
/// is inclusive such that it can appropriately represent the entire int32
361361
/// domain.
362362
#[allow(clippy::derive_partial_eq_without_eq)]
363-
#[derive(Clone, PartialEq, ::prost::Message)]
363+
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
364364
pub struct EnumReservedRange {
365365
/// Inclusive.
366366
#[prost(int32, optional, tag = "1")]
@@ -1853,7 +1853,7 @@ pub struct Mixin {
18531853
/// be expressed in JSON format as "3.000000001s", and 3 seconds and 1
18541854
/// microsecond should be expressed in JSON format as "3.000001s".
18551855
#[allow(clippy::derive_partial_eq_without_eq)]
1856-
#[derive(Clone, PartialEq, ::prost::Message)]
1856+
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
18571857
pub struct Duration {
18581858
/// Signed seconds of the span of time. Must be from -315,576,000,000
18591859
/// to +315,576,000,000 inclusive. Note: these bounds are computed from:
@@ -2293,7 +2293,7 @@ impl NullValue {
22932293
/// the time format spec '%Y-%m-%dT%H:%M:%S.%fZ'. Likewise, in Java, one can use
22942294
/// the Joda Time's [`ISODateTimeFormat.dateTime()`](<http://www.joda.org/joda-time/apidocs/org/joda/time/format/ISODateTimeFormat.html#dateTime%2D%2D>) to obtain a formatter capable of generating timestamps in this format.
22952295
#[allow(clippy::derive_partial_eq_without_eq)]
2296-
#[derive(Clone, PartialEq, ::prost::Message)]
2296+
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
22972297
pub struct Timestamp {
22982298
/// Represents seconds of UTC time since Unix epoch
22992299
/// 1970-01-01T00:00:00Z. Must be from 0001-01-01T00:00:00Z to

prost-types/src/timestamp.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ impl Timestamp {
5050
///
5151
/// [1]: https://github.com/google/protobuf/blob/v3.3.2/src/google/protobuf/util/time_util.cc#L59-L77
5252
pub fn try_normalize(mut self) -> Result<Timestamp, Timestamp> {
53-
let before = self.clone();
53+
let before = self;
5454
self.normalize();
5555
// If the seconds value has changed, and is either i64::MIN or i64::MAX, then the timestamp
5656
// normalization overflowed.
@@ -201,7 +201,7 @@ impl TryFrom<Timestamp> for std::time::SystemTime {
201201
type Error = TimestampError;
202202

203203
fn try_from(mut timestamp: Timestamp) -> Result<std::time::SystemTime, Self::Error> {
204-
let orig_timestamp = timestamp.clone();
204+
let orig_timestamp = timestamp;
205205
timestamp.normalize();
206206

207207
let system_time = if timestamp.seconds >= 0 {
@@ -211,8 +211,7 @@ impl TryFrom<Timestamp> for std::time::SystemTime {
211211
timestamp
212212
.seconds
213213
.checked_neg()
214-
.ok_or_else(|| TimestampError::OutOfSystemRange(timestamp.clone()))?
215-
as u64,
214+
.ok_or(TimestampError::OutOfSystemRange(timestamp))? as u64,
216215
))
217216
};
218217

@@ -234,7 +233,7 @@ impl FromStr for Timestamp {
234233

235234
impl fmt::Display for Timestamp {
236235
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
237-
datetime::DateTime::from(self.clone()).fmt(f)
236+
datetime::DateTime::from(*self).fmt(f)
238237
}
239238
}
240239
#[cfg(test)]
@@ -262,7 +261,7 @@ mod tests {
262261
) {
263262
let mut timestamp = Timestamp { seconds, nanos };
264263
timestamp.normalize();
265-
if let Ok(system_time) = SystemTime::try_from(timestamp.clone()) {
264+
if let Ok(system_time) = SystemTime::try_from(timestamp) {
266265
prop_assert_eq!(Timestamp::from(system_time), timestamp);
267266
}
268267
}

tests/src/build.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ fn main() {
9191
.compile_protos(&[src.join("deprecated_field.proto")], includes)
9292
.unwrap();
9393

94+
config
95+
.compile_protos(&[src.join("derive_copy.proto")], includes)
96+
.unwrap();
97+
9498
config
9599
.compile_protos(&[src.join("default_string_escape.proto")], includes)
96100
.unwrap();

tests/src/derive_copy.proto

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
syntax = "proto3";
2+
3+
import "google/protobuf/timestamp.proto";
4+
5+
package derive_copy;
6+
7+
message EmptyMsg {}
8+
9+
message IntegerMsg {
10+
int32 field1 = 1;
11+
int64 field2 = 2;
12+
uint32 field3 = 3;
13+
uint64 field4 = 4;
14+
sint32 field5 = 5;
15+
sint64 field6 = 6;
16+
fixed32 field7 = 7;
17+
fixed64 field8 = 8;
18+
sfixed32 field9 = 9;
19+
sfixed64 field10 = 10;
20+
}
21+
22+
message FloatMsg {
23+
double field1 = 1;
24+
float field2 = 2;
25+
}
26+
27+
message BoolMsg { bool field1 = 1; }
28+
29+
enum AnEnum {
30+
A = 0;
31+
B = 1;
32+
};
33+
34+
message EnumMsg { AnEnum field1 = 1; }
35+
36+
message OneOfMsg {
37+
oneof data {
38+
int32 field1 = 1;
39+
int64 field2 = 2;
40+
}
41+
}
42+
43+
message ComposedMsg {
44+
IntegerMsg field1 = 1;
45+
EnumMsg field2 = 2;
46+
OneOfMsg field3 = 3;
47+
}
48+
49+
message WellKnownMsg {
50+
google.protobuf.Timestamp timestamp = 1;
51+
}

0 commit comments

Comments
 (0)