Skip to content

Commit 56f65d7

Browse files
committed
feat: Support merging overload annotations into implementation
Issue-442: #442
1 parent b4b502b commit 56f65d7

File tree

2 files changed

+38
-17
lines changed

2 files changed

+38
-17
lines changed

src/griffe/_internal/expressions.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,22 @@ def is_generator(self) -> bool:
250250
"""Whether this expression is a generator."""
251251
return isinstance(self, ExprSubscript) and self.canonical_name == "Generator"
252252

253+
@staticmethod
254+
def _to_binop(elements: Sequence[Expr], op: str) -> ExprBinOp:
255+
if len(elements) == 2: # noqa: PLR2004
256+
left, right = elements
257+
if isinstance(left, Expr):
258+
left = left.modernize()
259+
if isinstance(right, Expr):
260+
right = right.modernize()
261+
return ExprBinOp(left=left, operator=op, right=right)
262+
263+
left = ExprSubscript._to_binop(elements[:-1], op=op)
264+
right = elements[-1]
265+
if isinstance(right, Expr):
266+
right = right.modernize()
267+
return ExprBinOp(left=left, operator=op, right=right)
268+
253269

254270
@dataclass(eq=True, slots=True)
255271
class ExprAttribute(Expr):
@@ -873,22 +889,6 @@ def iterate(self, *, flat: bool = True) -> Iterator[str | Expr]:
873889
yield from _yield(self.slice, flat=flat, outer_precedence=_OperatorPrecedence.NONE)
874890
yield "]"
875891

876-
@staticmethod
877-
def _to_binop(elements: Sequence[Expr], op: str) -> ExprBinOp:
878-
if len(elements) == 2: # noqa: PLR2004
879-
left, right = elements
880-
if isinstance(left, Expr):
881-
left = left.modernize()
882-
if isinstance(right, Expr):
883-
right = right.modernize()
884-
return ExprBinOp(left=left, operator=op, right=right)
885-
886-
left = ExprSubscript._to_binop(elements[:-1], op=op)
887-
right = elements[-1]
888-
if isinstance(right, Expr):
889-
right = right.modernize()
890-
return ExprBinOp(left=left, operator=op, right=right)
891-
892892
def modernize(self) -> ExprBinOp | ExprSubscript:
893893
if self.canonical_path == "typing.Union":
894894
return self._to_binop(self.slice.elements, op="|") # type: ignore[union-attr]

src/griffe/_internal/merger.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import TYPE_CHECKING
77

88
from griffe._internal.exceptions import AliasResolutionError, CyclicAliasError
9+
from griffe._internal.expressions import Expr
910
from griffe._internal.logger import logger
1011

1112
if TYPE_CHECKING:
@@ -60,10 +61,30 @@ def _merge_stubs_overloads(obj: Module | Class, stubs: Module | Class) -> None:
6061
for function_name, overloads in list(stubs.overloads.items()):
6162
if overloads:
6263
with suppress(KeyError):
63-
obj.get_member(function_name).overloads = overloads
64+
_merge_overload_annotations(obj.get_member(function_name), overloads)
6465
del stubs.overloads[function_name]
6566

6667

68+
def _merge_overload_annotations(function: Function, overloads: list[Function]) -> None:
69+
function.overloads = overloads
70+
for parameter in function.parameters:
71+
if parameter.annotation is None:
72+
annotations = []
73+
seen = set()
74+
for overload in overloads:
75+
with suppress(KeyError):
76+
annotation = overload.parameters[parameter.name].annotation
77+
str_annotation = str(annotation)
78+
if isinstance(annotation, Expr) and str_annotation not in seen:
79+
annotations.append(annotation)
80+
seen.add(str_annotation)
81+
if annotations:
82+
if len(annotations) == 1:
83+
parameter.annotation = annotations[0]
84+
else:
85+
parameter.annotation = Expr._to_binop(annotations, op="|")
86+
87+
6788
def _merge_stubs_members(obj: Module | Class, stubs: Module | Class) -> None:
6889
# Merge imports to later know if objects coming from the stubs were imported.
6990
obj.imports.update(stubs.imports)

0 commit comments

Comments
 (0)