Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion bazel/py_proto_library.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ _PyProtoInfo = provider(
(depset[File]) Files from the transitive closure implicit proto
dependencies""",
"transitive_sources": """(depset[File]) The Python sources.""",
"direct_pyi_files": """
(depset[File]) Type definition files (usually `.pyi` files)
for the Python modules provided by this target.""",
"transitive_pyi_files": """
(depset[File]) The transitive set of type definition files
(usually `.pyi` files) for the Python modules for this target
and its transitive dependencies.""",
},
)

Expand Down Expand Up @@ -61,6 +68,7 @@ def _py_proto_aspect_impl(target, ctx):
api_deps = [proto_lang_toolchain_info.runtime]

generated_sources = []
generated_stubs = []
proto_info = target[ProtoInfo]
proto_root = proto_info.proto_source_root
if proto_info.direct_sources:
Expand All @@ -72,6 +80,14 @@ def _py_proto_aspect_impl(target, ctx):
name_mapper = lambda name: name.replace("-", "_").replace(".", "/"),
)

# Generate pyi files
generated_stubs = proto_common.declare_generated_files(
actions = ctx.actions,
proto_info = proto_info,
extension = "_pb2.pyi",
name_mapper = lambda name: name.replace("-", "_").replace(".", "/"),
)

# Handles multiple repository and virtual import cases
if proto_root.startswith(ctx.bin_dir.path):
proto_root = proto_root[len(ctx.bin_dir.path) + 1:]
Expand All @@ -84,12 +100,17 @@ def _py_proto_aspect_impl(target, ctx):
else:
proto_root = ctx.workspace_name + "/" + proto_root

additional_args = ctx.actions.args()
if generated_stubs:
additional_args.add(plugin_output, format="--pyi_out=%s")

proto_common.compile(
actions = ctx.actions,
proto_info = proto_info,
proto_lang_toolchain_info = proto_lang_toolchain_info,
generated_files = generated_sources,
generated_files = generated_sources + generated_stubs,
plugin_output = plugin_output,
additional_args = additional_args,
)

# Generated sources == Python sources
Expand All @@ -104,6 +125,13 @@ def _py_proto_aspect_impl(target, ctx):
direct = python_sources,
transitive = [dep.transitive_sources for dep in deps],
)
direct_pyi_files = depset(
direct = generated_stubs,
)
transitive_pyi_files = depset(
direct = generated_stubs,
transitive = [dep.transitive_pyi_files for dep in deps],
)

return [
_PyProtoInfo(
Expand All @@ -119,6 +147,8 @@ def _py_proto_aspect_impl(target, ctx):
),
runfiles_from_proto_deps = runfiles_from_proto_deps,
transitive_sources = transitive_sources,
direct_pyi_files = direct_pyi_files,
transitive_pyi_files = transitive_pyi_files,
),
]

Expand Down Expand Up @@ -150,6 +180,13 @@ def _py_proto_library_rule(ctx):
default_outputs = depset(
transitive = [info.transitive_sources for info in pyproto_infos],
)
direct_pyi_files = []
for info in pyproto_infos:
direct_pyi_files.extend(info.direct_pyi_files.to_list())
transitive_pyi_files = depset(
direct = direct_pyi_files,
transitive = [info.transitive_pyi_files for info in pyproto_infos],
)

return [
DefaultInfo(
Expand All @@ -166,6 +203,8 @@ def _py_proto_library_rule(ctx):
PyInfo(
transitive_sources = default_outputs,
imports = depset(transitive = [info.imports for info in pyproto_infos]),
direct_pyi_files = depset(direct = direct_pyi_files),
transitive_pyi_files = transitive_pyi_files,
# Proto always produces 2- and 3- compatible source files
has_py2_only_sources = False,
has_py3_only_sources = False,
Expand Down
Loading