Skip to content

Commit df35e21

Browse files
damelLPryankurte
authored andcommitted
Add message and enum attributes to prost-build (tokio-rs#784)
closes tokio-rs#783
1 parent 4726f93 commit df35e21

File tree

4 files changed

+194
-0
lines changed

4 files changed

+194
-0
lines changed

prost-build/src/code_generator.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ impl<'a> CodeGenerator<'a> {
180180

181181
self.append_doc(&fq_message_name, None);
182182
self.append_type_attributes(&fq_message_name);
183+
self.append_message_attributes(&fq_message_name);
183184
self.push_indent();
184185
self.buf
185186
.push_str("#[allow(clippy::derive_partial_eq_without_eq)]\n");
@@ -270,6 +271,24 @@ impl<'a> CodeGenerator<'a> {
270271
}
271272
}
272273

274+
fn append_message_attributes(&mut self, fq_message_name: &str) {
275+
assert_eq!(b'.', fq_message_name.as_bytes()[0]);
276+
for attribute in self.config.message_attributes.get(fq_message_name) {
277+
push_indent(self.buf, self.depth);
278+
self.buf.push_str(attribute);
279+
self.buf.push('\n');
280+
}
281+
}
282+
283+
fn append_enum_attributes(&mut self, fq_message_name: &str) {
284+
assert_eq!(b'.', fq_message_name.as_bytes()[0]);
285+
for attribute in self.config.enum_attributes.get(fq_message_name) {
286+
push_indent(self.buf, self.depth);
287+
self.buf.push_str(attribute);
288+
self.buf.push('\n');
289+
}
290+
}
291+
273292
fn append_field_attributes(&mut self, fq_message_name: &str, field_name: &str) {
274293
assert_eq!(b'.', fq_message_name.as_bytes()[0]);
275294
for attribute in self
@@ -504,6 +523,7 @@ impl<'a> CodeGenerator<'a> {
504523

505524
let oneof_name = format!("{}.{}", fq_message_name, oneof.name());
506525
self.append_type_attributes(&oneof_name);
526+
self.append_enum_attributes(&oneof_name);
507527
self.push_indent();
508528
self.buf
509529
.push_str("#[allow(clippy::derive_partial_eq_without_eq)]\n");
@@ -615,6 +635,7 @@ impl<'a> CodeGenerator<'a> {
615635

616636
self.append_doc(&fq_proto_enum_name, None);
617637
self.append_type_attributes(&fq_proto_enum_name);
638+
self.append_enum_attributes(&fq_proto_enum_name);
618639
self.push_indent();
619640
self.buf.push_str(
620641
&format!("#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, {}::Enumeration)]\n",self.config.prost_path.as_deref().unwrap_or("::prost")),
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#[derive(derive_builder::Builder)]
2+
#[allow(clippy::derive_partial_eq_without_eq)]
3+
#[derive(Clone, PartialEq, ::prost::Message)]
4+
pub struct Message {
5+
#[prost(string, tag = "1")]
6+
pub say: ::prost::alloc::string::String,
7+
}
8+
#[derive(derive_builder::Builder)]
9+
#[allow(clippy::derive_partial_eq_without_eq)]
10+
#[derive(Clone, PartialEq, ::prost::Message)]
11+
pub struct Response {
12+
#[prost(string, tag = "1")]
13+
pub say: ::prost::alloc::string::String,
14+
}
15+
#[some_enum_attr(u8)]
16+
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
17+
#[repr(i32)]
18+
pub enum ServingStatus {
19+
Unknown = 0,
20+
Serving = 1,
21+
NotServing = 2,
22+
}
23+
impl ServingStatus {
24+
/// String value of the enum field names used in the ProtoBuf definition.
25+
///
26+
/// The values are not transformed in any way and thus are considered stable
27+
/// (if the ProtoBuf definition does not change) and safe for programmatic use.
28+
pub fn as_str_name(&self) -> &'static str {
29+
match self {
30+
ServingStatus::Unknown => "UNKNOWN",
31+
ServingStatus::Serving => "SERVING",
32+
ServingStatus::NotServing => "NOT_SERVING",
33+
}
34+
}
35+
/// Creates an enum from field names used in the ProtoBuf definition.
36+
pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
37+
match value {
38+
"UNKNOWN" => Some(Self::Unknown),
39+
"SERVING" => Some(Self::Serving),
40+
"NOT_SERVING" => Some(Self::NotServing),
41+
_ => None,
42+
}
43+
}
44+
}

prost-build/src/fixtures/helloworld/types.proto

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,9 @@ message Message {
99
message Response {
1010
string say = 1;
1111
}
12+
13+
enum ServingStatus {
14+
UNKNOWN = 0;
15+
SERVING = 1;
16+
NOT_SERVING = 2;
17+
}

prost-build/src/lib.rs

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,8 @@ pub struct Config {
244244
map_type: PathMap<MapType>,
245245
bytes_type: PathMap<BytesType>,
246246
type_attributes: PathMap<String>,
247+
message_attributes: PathMap<String>,
248+
enum_attributes: PathMap<String>,
247249
field_attributes: PathMap<String>,
248250
prost_types: bool,
249251
strip_enum_prefix: bool,
@@ -468,6 +470,94 @@ impl Config {
468470
self
469471
}
470472

473+
/// Add additional attribute to matched messages.
474+
///
475+
/// # Arguments
476+
///
477+
/// **`paths`** - a path matching any number of types. It works the same way as in
478+
/// [`btree_map`](#method.btree_map), just with the field name omitted.
479+
///
480+
/// **`attribute`** - an arbitrary string to be placed before each matched type. The
481+
/// expected usage are additional attributes, but anything is allowed.
482+
///
483+
/// The calls to this method are cumulative. They don't overwrite previous calls and if a
484+
/// type is matched by multiple calls of the method, all relevant attributes are added to
485+
/// it.
486+
///
487+
/// For things like serde it might be needed to combine with [field
488+
/// attributes](#method.field_attribute).
489+
///
490+
/// # Examples
491+
///
492+
/// ```rust
493+
/// # let mut config = prost_build::Config::new();
494+
/// // Nothing around uses floats, so we can derive real `Eq` in addition to `PartialEq`.
495+
/// config.message_attribute(".", "#[derive(Eq)]");
496+
/// // Some messages want to be serializable with serde as well.
497+
/// config.message_attribute("my_messages.MyMessageType",
498+
/// "#[derive(Serialize)] #[serde(rename_all = \"snake_case\")]");
499+
/// config.message_attribute("my_messages.MyMessageType.MyNestedMessageType",
500+
/// "#[derive(Serialize)] #[serde(rename_all = \"snake_case\")]");
501+
/// ```
502+
pub fn message_attribute<P, A>(&mut self, path: P, attribute: A) -> &mut Self
503+
where
504+
P: AsRef<str>,
505+
A: AsRef<str>,
506+
{
507+
self.message_attributes
508+
.insert(path.as_ref().to_string(), attribute.as_ref().to_string());
509+
self
510+
}
511+
512+
/// Add additional attribute to matched enums and one-ofs.
513+
///
514+
/// # Arguments
515+
///
516+
/// **`paths`** - a path matching any number of types. It works the same way as in
517+
/// [`btree_map`](#method.btree_map), just with the field name omitted.
518+
///
519+
/// **`attribute`** - an arbitrary string to be placed before each matched type. The
520+
/// expected usage are additional attributes, but anything is allowed.
521+
///
522+
/// The calls to this method are cumulative. They don't overwrite previous calls and if a
523+
/// type is matched by multiple calls of the method, all relevant attributes are added to
524+
/// it.
525+
///
526+
/// For things like serde it might be needed to combine with [field
527+
/// attributes](#method.field_attribute).
528+
///
529+
/// # Examples
530+
///
531+
/// ```rust
532+
/// # let mut config = prost_build::Config::new();
533+
/// // Nothing around uses floats, so we can derive real `Eq` in addition to `PartialEq`.
534+
/// config.enum_attribute(".", "#[derive(Eq)]");
535+
/// // Some messages want to be serializable with serde as well.
536+
/// config.enum_attribute("my_messages.MyEnumType",
537+
/// "#[derive(Serialize)] #[serde(rename_all = \"snake_case\")]");
538+
/// config.enum_attribute("my_messages.MyMessageType.MyNestedEnumType",
539+
/// "#[derive(Serialize)] #[serde(rename_all = \"snake_case\")]");
540+
/// ```
541+
///
542+
/// # Oneof fields
543+
///
544+
/// The `oneof` fields don't have a type name of their own inside Protobuf. Therefore, the
545+
/// field name can be used both with `enum_attribute` and `field_attribute` ‒ the first is
546+
/// placed before the `enum` type definition, the other before the field inside corresponding
547+
/// message `struct`.
548+
///
549+
/// In other words, to place an attribute on the `enum` implementing the `oneof`, the match
550+
/// would look like `my_messages.MyNestedMessageType.oneofname`.
551+
pub fn enum_attribute<P, A>(&mut self, path: P, attribute: A) -> &mut Self
552+
where
553+
P: AsRef<str>,
554+
A: AsRef<str>,
555+
{
556+
self.enum_attributes
557+
.insert(path.as_ref().to_string(), attribute.as_ref().to_string());
558+
self
559+
}
560+
471561
/// Configures the code generator to use the provided service generator.
472562
pub fn service_generator(&mut self, service_generator: Box<dyn ServiceGenerator>) -> &mut Self {
473563
self.service_generator = Some(service_generator);
@@ -1099,6 +1189,8 @@ impl default::Default for Config {
10991189
map_type: PathMap::default(),
11001190
bytes_type: PathMap::default(),
11011191
type_attributes: PathMap::default(),
1192+
message_attributes: PathMap::default(),
1193+
enum_attributes: PathMap::default(),
11021194
field_attributes: PathMap::default(),
11031195
prost_types: true,
11041196
strip_enum_prefix: true,
@@ -1425,6 +1517,37 @@ mod tests {
14251517
assert_eq!(state.finalized, 3);
14261518
}
14271519

1520+
#[test]
1521+
fn test_generate_message_attributes() {
1522+
let _ = env_logger::try_init();
1523+
1524+
let out_dir = std::env::temp_dir();
1525+
1526+
Config::new()
1527+
.out_dir(out_dir.clone())
1528+
.message_attribute(".", "#[derive(derive_builder::Builder)]")
1529+
.enum_attribute(".", "#[some_enum_attr(u8)]")
1530+
.compile_protos(
1531+
&["src/fixtures/helloworld/hello.proto"],
1532+
&["src/fixtures/helloworld"],
1533+
)
1534+
.unwrap();
1535+
1536+
let out_file = out_dir
1537+
.join("helloworld.rs")
1538+
.as_path()
1539+
.display()
1540+
.to_string();
1541+
let expected_content = read_all_content("src/fixtures/helloworld/_expected_helloworld.rs")
1542+
.replace("\r\n", "\n");
1543+
let content = read_all_content(&out_file).replace("\r\n", "\n");
1544+
assert_eq!(
1545+
expected_content, content,
1546+
"Unexpected content: \n{}",
1547+
content
1548+
);
1549+
}
1550+
14281551
#[test]
14291552
fn test_generate_no_empty_outputs() {
14301553
let _ = env_logger::try_init();

0 commit comments

Comments
 (0)