Skip to content

Commit baddf98

Browse files
authored
fix: include_file should handle proto without package (#1002)
* fix #1001 and add tests * add alloc:: imports * rewrite write_includes to allow for empty modules. * create test fixture for `write_includes` * fix lints, remove line feeds * fixes after merge master * remove some duplicate tests and alter existing ones to test write_includes * more test * module.rs Module::starts_with visibility
1 parent 1f38ea6 commit baddf98

File tree

6 files changed

+138
-58
lines changed

6 files changed

+138
-58
lines changed

prost-build/src/config.rs

Lines changed: 44 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -826,8 +826,8 @@ impl Config {
826826
self.write_includes(
827827
modules.keys().collect(),
828828
&mut file,
829-
0,
830829
if target_is_env { None } else { Some(&target) },
830+
&file_names,
831831
)?;
832832
file.flush()?;
833833
}
@@ -955,67 +955,58 @@ impl Config {
955955
self.compile_fds(file_descriptor_set)
956956
}
957957

958-
fn write_includes(
958+
pub(crate) fn write_includes(
959959
&self,
960-
mut entries: Vec<&Module>,
961-
outfile: &mut fs::File,
962-
depth: usize,
960+
mut modules: Vec<&Module>,
961+
outfile: &mut impl Write,
963962
basepath: Option<&PathBuf>,
964-
) -> Result<usize> {
965-
let mut written = 0;
966-
entries.sort();
967-
968-
while !entries.is_empty() {
969-
let modident = entries[0].part(depth);
970-
let matching: Vec<&Module> = entries
971-
.iter()
972-
.filter(|&v| v.part(depth) == modident)
973-
.copied()
974-
.collect();
975-
{
976-
// Will NLL sort this mess out?
977-
let _temp = entries
978-
.drain(..)
979-
.filter(|&v| v.part(depth) != modident)
980-
.collect();
981-
entries = _temp;
963+
file_names: &HashMap<Module, String>,
964+
) -> Result<()> {
965+
modules.sort();
966+
967+
let mut stack = Vec::new();
968+
969+
for module in modules {
970+
while !module.starts_with(&stack) {
971+
stack.pop();
972+
self.write_line(outfile, stack.len(), "}")?;
982973
}
983-
self.write_line(outfile, depth, &format!("pub mod {} {{", modident))?;
984-
let subwritten = self.write_includes(
985-
matching
986-
.iter()
987-
.filter(|v| v.len() > depth + 1)
988-
.copied()
989-
.collect(),
990-
outfile,
991-
depth + 1,
992-
basepath,
993-
)?;
994-
written += subwritten;
995-
if subwritten != matching.len() {
996-
let modname = matching[0].to_partial_file_name(..=depth);
997-
if basepath.is_some() {
998-
self.write_line(
999-
outfile,
1000-
depth + 1,
1001-
&format!("include!(\"{}.rs\");", modname),
1002-
)?;
1003-
} else {
1004-
self.write_line(
1005-
outfile,
1006-
depth + 1,
1007-
&format!("include!(concat!(env!(\"OUT_DIR\"), \"/{}.rs\"));", modname),
1008-
)?;
1009-
}
1010-
written += 1;
974+
while stack.len() < module.len() {
975+
self.write_line(
976+
outfile,
977+
stack.len(),
978+
&format!("pub mod {} {{", module.part(stack.len())),
979+
)?;
980+
stack.push(module.part(stack.len()).to_owned());
1011981
}
1012982

983+
let file_name = file_names
984+
.get(module)
985+
.expect("every module should have a filename");
986+
987+
if basepath.is_some() {
988+
self.write_line(
989+
outfile,
990+
stack.len(),
991+
&format!("include!(\"{}\");", file_name),
992+
)?;
993+
} else {
994+
self.write_line(
995+
outfile,
996+
stack.len(),
997+
&format!("include!(concat!(env!(\"OUT_DIR\"), \"/{}\"));", file_name),
998+
)?;
999+
}
1000+
}
1001+
1002+
for depth in (0..stack.len()).rev() {
10131003
self.write_line(outfile, depth, "}")?;
10141004
}
1015-
Ok(written)
1005+
1006+
Ok(())
10161007
}
10171008

1018-
fn write_line(&self, outfile: &mut fs::File, depth: usize, line: &str) -> Result<()> {
1009+
fn write_line(&self, outfile: &mut impl Write, depth: usize, line: &str) -> Result<()> {
10191010
outfile.write_all(format!("{}{}\n", (" ").to_owned().repeat(depth), line).as_bytes())
10201011
}
10211012

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
include!(concat!(env!("OUT_DIR"), "/_.default.rs"));
2+
pub mod bar {
3+
include!(concat!(env!("OUT_DIR"), "/bar.rs"));
4+
}
5+
pub mod foo {
6+
include!(concat!(env!("OUT_DIR"), "/foo.rs"));
7+
pub mod bar {
8+
include!(concat!(env!("OUT_DIR"), "/foo.bar.rs"));
9+
pub mod a {
10+
pub mod b {
11+
pub mod c {
12+
include!(concat!(env!("OUT_DIR"), "/foo.bar.a.b.c.rs"));
13+
}
14+
}
15+
}
16+
pub mod baz {
17+
include!(concat!(env!("OUT_DIR"), "/foo.bar.baz.rs"));
18+
}
19+
pub mod qux {
20+
include!(concat!(env!("OUT_DIR"), "/foo.bar.qux.rs"));
21+
}
22+
}
23+
}

prost-build/src/lib.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,4 +530,32 @@ mod tests {
530530
f.read_to_string(&mut content).unwrap();
531531
content
532532
}
533+
534+
#[test]
535+
fn write_includes() {
536+
let modules = [
537+
Module::from_protobuf_package_name("foo.bar.baz"),
538+
Module::from_protobuf_package_name(""),
539+
Module::from_protobuf_package_name("foo.bar"),
540+
Module::from_protobuf_package_name("bar"),
541+
Module::from_protobuf_package_name("foo"),
542+
Module::from_protobuf_package_name("foo.bar.qux"),
543+
Module::from_protobuf_package_name("foo.bar.a.b.c"),
544+
];
545+
546+
let file_names = modules
547+
.iter()
548+
.map(|m| (m.clone(), m.to_file_name_or("_.default")))
549+
.collect();
550+
551+
let mut buf = Vec::new();
552+
Config::new()
553+
.default_package_filename("_.default")
554+
.write_includes(modules.iter().collect(), &mut buf, None, &file_names)
555+
.unwrap();
556+
let expected =
557+
read_all_content("src/fixtures/write_includes/_.includes.rs").replace("\r\n", "\n");
558+
let actual = String::from_utf8(buf).unwrap();
559+
assert_eq!(expected, actual);
560+
}
533561
}

prost-build/src/module.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
use std::fmt;
2-
use std::ops::RangeToInclusive;
32

43
use crate::ident::to_snake;
54

@@ -40,6 +39,15 @@ impl Module {
4039
self.components.iter().map(|s| s.as_str())
4140
}
4241

42+
#[must_use]
43+
#[inline(always)]
44+
pub(crate) fn starts_with(&self, needle: &[String]) -> bool
45+
where
46+
String: PartialEq,
47+
{
48+
self.components.starts_with(needle)
49+
}
50+
4351
/// Format the module path into a filename for generated Rust code.
4452
///
4553
/// If the module path is empty, `default` is used to provide the root of the filename.
@@ -65,10 +73,6 @@ impl Module {
6573
self.components.is_empty()
6674
}
6775

68-
pub(crate) fn to_partial_file_name(&self, range: RangeToInclusive<usize>) -> String {
69-
self.components[range].join(".")
70-
}
71-
7276
pub(crate) fn part(&self, idx: usize) -> &str {
7377
self.components[idx].as_str()
7478
}

tests/src/build.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ fn main() {
178178
no_root_packages_config
179179
.out_dir(&no_root_packages)
180180
.default_package_filename("__.default")
181+
.include_file("__.include.rs")
181182
.compile_protos(
182183
&[src.join("no_root_packages/widget_factory.proto")],
183184
&[src.join("no_root_packages")],

tests/src/no_root_packages/mod.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ pub mod widget {
1616
}
1717
}
1818

19+
pub mod generated_include {
20+
include!(concat!(env!("OUT_DIR"), "/no_root_packages/__.include.rs"));
21+
}
22+
1923
#[test]
2024
fn test() {
2125
use prost::Message;
@@ -44,3 +48,32 @@ fn test() {
4448
widget_factory.gizmo_inner = Some(gizmo::gizmo::Inner {});
4549
assert_eq!(14, widget_factory.encoded_len());
4650
}
51+
52+
#[test]
53+
fn generated_include() {
54+
use prost::Message;
55+
56+
let mut widget_factory = generated_include::widget::factory::WidgetFactory::default();
57+
assert_eq!(0, widget_factory.encoded_len());
58+
59+
widget_factory.inner = Some(generated_include::widget::factory::widget_factory::Inner {});
60+
assert_eq!(2, widget_factory.encoded_len());
61+
62+
widget_factory.root = Some(generated_include::Root {});
63+
assert_eq!(4, widget_factory.encoded_len());
64+
65+
widget_factory.root_inner = Some(generated_include::root::Inner {});
66+
assert_eq!(6, widget_factory.encoded_len());
67+
68+
widget_factory.widget = Some(generated_include::widget::Widget {});
69+
assert_eq!(8, widget_factory.encoded_len());
70+
71+
widget_factory.widget_inner = Some(generated_include::widget::widget::Inner {});
72+
assert_eq!(10, widget_factory.encoded_len());
73+
74+
widget_factory.gizmo = Some(generated_include::gizmo::Gizmo {});
75+
assert_eq!(12, widget_factory.encoded_len());
76+
77+
widget_factory.gizmo_inner = Some(generated_include::gizmo::gizmo::Inner {});
78+
assert_eq!(14, widget_factory.encoded_len());
79+
}

0 commit comments

Comments
 (0)