Skip to content

Commit 1332b1a

Browse files
committed
Fix type issues and implement lazy value support in base metadata provider too
1 parent 87ac7f5 commit 1332b1a

6 files changed

Lines changed: 36 additions & 48 deletions

File tree

libcst/_metadata_dependent.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from contextlib import contextmanager
99
from typing import (
1010
Callable,
11+
Generic,
12+
Union,
1113
cast,
1214
ClassVar,
1315
Collection,
@@ -30,12 +32,12 @@
3032

3133
_T = TypeVar("_T")
3234

33-
_UNDEFINED_DEFAULT = object()
3435

35-
_SENTINEL = object()
36+
class _UNDEFINED_DEFAULT:
37+
pass
3638

3739

38-
class LazyValue:
40+
class LazyValue(Generic[_T]):
3941
"""
4042
The class for implementing a lazy metadata loading mechanism that improves the
4143
performance when retriving expensive metadata (e.g., qualified names). Providers
@@ -46,12 +48,12 @@ class LazyValue:
4648

4749
def __init__(self, callable: Callable[[], _T]) -> None:
4850
self.callable = callable
49-
self.return_value: object = _SENTINEL
51+
self.return_value: Union[_T, Type[_UNDEFINED_DEFAULT]] = _UNDEFINED_DEFAULT
5052

51-
def __call__(self) -> object:
52-
if self.return_value is _SENTINEL:
53+
def __call__(self) -> _T:
54+
if self.return_value is _UNDEFINED_DEFAULT:
5355
self.return_value = self.callable()
54-
return self.return_value
56+
return cast(_T, self.return_value)
5557

5658

5759
class MetadataDependent(ABC):

libcst/codemod/visitors/_apply_type_annotations.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from collections import defaultdict
77
from dataclasses import dataclass
8-
from typing import cast, Collection, Dict, List, Optional, Sequence, Set, Tuple, Union
8+
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
99

1010
import libcst as cst
1111
import libcst.matchers as m
@@ -17,7 +17,7 @@
1717
from libcst.codemod.visitors._gather_imports import GatherImportsVisitor
1818
from libcst.codemod.visitors._imports import ImportItem
1919
from libcst.helpers import get_full_name_for_node
20-
from libcst.metadata import PositionProvider, QualifiedName, QualifiedNameProvider
20+
from libcst.metadata import PositionProvider, QualifiedNameProvider
2121

2222

2323
NameOrAttribute = Union[cst.Name, cst.Attribute]
@@ -48,12 +48,7 @@ def _get_unique_qualified_name(
4848
visitor: m.MatcherDecoratableVisitor, node: cst.CSTNode
4949
) -> str:
5050
name = None
51-
names = [
52-
q.name
53-
for q in cast(
54-
Collection[QualifiedName], visitor.get_metadata(QualifiedNameProvider, node)
55-
)
56-
]
51+
names = [q.name for q in visitor.get_metadata(QualifiedNameProvider, node)]
5752
if len(names) == 0:
5853
# we hit this branch if the stub is directly using a fully
5954
# qualified name, which is not technically valid python but is

libcst/codemod/visitors/_gather_string_annotation_names.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import libcst.matchers as m
1010
from libcst.codemod._context import CodemodContext
1111
from libcst.codemod._visitor import ContextAwareVisitor
12-
from libcst.metadata import MetadataWrapper, QualifiedName, QualifiedNameProvider
12+
from libcst.metadata import MetadataWrapper, QualifiedNameProvider
1313

1414
FUNCS_CONSIDERED_AS_STRING_ANNOTATIONS = {"typing.TypeVar"}
1515

@@ -45,9 +45,7 @@ def leave_Annotation(self, original_node: cst.Annotation) -> None:
4545
self._annotation_stack.pop()
4646

4747
def visit_Call(self, node: cst.Call) -> bool:
48-
qnames = cast(
49-
Collection[QualifiedName], self.get_metadata(QualifiedNameProvider, node)
50-
)
48+
qnames = self.get_metadata(QualifiedNameProvider, node)
5149
if any(qn.name in self._typing_functions for qn in qnames):
5250
self._annotation_stack.append(node)
5351
return True

libcst/metadata/base_provider.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from types import MappingProxyType
88
from typing import (
99
Callable,
10-
cast,
10+
Union,
1111
Generic,
1212
List,
1313
Mapping,
@@ -23,6 +23,7 @@
2323
_T as _MetadataT,
2424
_UNDEFINED_DEFAULT,
2525
MetadataDependent,
26+
LazyValue,
2627
)
2728
from libcst._visitors import CSTVisitor
2829

@@ -36,6 +37,7 @@
3637
# BaseMetadataProvider[int] would be a subtype of BaseMetadataProvider[object], so the
3738
# typevar is covariant.
3839
_ProvidedMetadataT = TypeVar("_ProvidedMetadataT", covariant=True)
40+
MaybeLazyMetadataT = Union[LazyValue[_ProvidedMetadataT], _ProvidedMetadataT]
3941

4042

4143
# We can't use an ABCMeta here, because of metaclass conflicts
@@ -52,16 +54,16 @@ class BaseMetadataProvider(MetadataDependent, Generic[_ProvidedMetadataT]):
5254
#
5355
# N.B. This has some typing variance problems. See `set_metadata` for an
5456
# explanation.
55-
_computed: MutableMapping["CSTNode", _ProvidedMetadataT]
57+
_computed: MutableMapping["CSTNode", MaybeLazyMetadataT]
5658

57-
#: Implement gen_cache to indicate the matadata provider depends on cache from external
59+
#: Implement gen_cache to indicate the metadata provider depends on cache from external
5860
#: system. This function will be called by :class:`~libcst.metadata.FullRepoManager`
5961
#: to compute required cache object per file path.
6062
gen_cache: Optional[Callable[[Path, List[str], int], Mapping[str, object]]] = None
6163

6264
def __init__(self, cache: object = None) -> None:
6365
super().__init__()
64-
self._computed = {}
66+
self._computed: MutableMapping["CSTNode", MaybeLazyMetadataT] = {}
6567
if self.gen_cache and cache is None:
6668
# The metadata provider implementation is responsible to store and use cache.
6769
raise Exception(
@@ -71,7 +73,7 @@ def __init__(self, cache: object = None) -> None:
7173

7274
def _gen(
7375
self, wrapper: "MetadataWrapper"
74-
) -> Mapping["CSTNode", _ProvidedMetadataT]:
76+
) -> Mapping["CSTNode", MaybeLazyMetadataT]:
7577
"""
7678
Resolves and returns metadata mapping for the module in ``wrapper``.
7779
@@ -93,11 +95,7 @@ def _gen_impl(self, module: "Module") -> None:
9395
"""
9496
...
9597

96-
# pyre-ignore[46]: The covariant `value` isn't type-safe because we write it to
97-
# pyre: `self._computed`, however we assume that only one subclass in the MRO chain
98-
# pyre: will ever call `set_metadata`, so it's okay for our purposes. There's no
99-
# pyre: sane way to redesign this API so that it doesn't have this problem.
100-
def set_metadata(self, node: "CSTNode", value: _ProvidedMetadataT) -> None:
98+
def set_metadata(self, node: "CSTNode", value: MaybeLazyMetadataT) -> None:
10199
"""
102100
Record a metadata value ``value`` for ``node``.
103101
"""
@@ -107,7 +105,9 @@ def get_metadata(
107105
self,
108106
key: Type["BaseMetadataProvider[_MetadataT]"],
109107
node: "CSTNode",
110-
default: _MetadataT = _UNDEFINED_DEFAULT,
108+
default: Union[
109+
MaybeLazyMetadataT, Type[_UNDEFINED_DEFAULT]
110+
] = _UNDEFINED_DEFAULT,
111111
) -> _MetadataT:
112112
"""
113113
The same method as :func:`~libcst.MetadataDependent.get_metadata` except
@@ -116,9 +116,12 @@ def get_metadata(
116116
"""
117117
if key is type(self):
118118
if default is not _UNDEFINED_DEFAULT:
119-
return cast(_MetadataT, self._computed.get(node, default))
119+
ret = self._computed.get(node, default)
120120
else:
121-
return cast(_MetadataT, self._computed[node])
121+
ret = self._computed[node]
122+
if isinstance(ret, LazyValue):
123+
return ret()
124+
return ret
122125

123126
return super().get_metadata(key, node, default)
124127

libcst/metadata/name_provider.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import dataclasses
77
from pathlib import Path
8-
from typing import cast, Collection, List, Mapping, Optional, Union
8+
from typing import Collection, List, Mapping, Optional, Union
99

1010
import libcst as cst
1111
from libcst._metadata_dependent import LazyValue, MetadataDependent
@@ -17,10 +17,8 @@
1717
ScopeProvider,
1818
)
1919

20-
_UNDEFINED_DEFAULT = object
2120

22-
23-
class QualifiedNameProvider(BatchableMetadataProvider[_UNDEFINED_DEFAULT]):
21+
class QualifiedNameProvider(BatchableMetadataProvider[Collection[QualifiedName]]):
2422
"""
2523
Compute possible qualified names of a variable CSTNode
2624
(extends `PEP-3155 <https://www.python.org/dev/peps/pep-3155/>`_).
@@ -66,10 +64,7 @@ def has_name(
6664
visitor: MetadataDependent, node: cst.CSTNode, name: Union[str, QualifiedName]
6765
) -> bool:
6866
"""Check if any of qualified name has the str name or :class:`~libcst.metadata.QualifiedName` name."""
69-
qualified_names = cast(
70-
Collection[QualifiedName],
71-
visitor.get_metadata(QualifiedNameProvider, node, set()),
72-
)
67+
qualified_names = visitor.get_metadata(QualifiedNameProvider, node, set())
7368
if isinstance(name, str):
7469
return any(qn.name == name for qn in qualified_names)
7570
else:
@@ -178,10 +173,7 @@ def __init__(
178173
self.provider = provider
179174

180175
def on_visit(self, node: cst.CSTNode) -> bool:
181-
qnames = cast(
182-
Collection[QualifiedName],
183-
self.provider.get_metadata(QualifiedNameProvider, node),
184-
)
176+
qnames = self.provider.get_metadata(QualifiedNameProvider, node)
185177
if qnames is not None:
186178
self.provider.set_metadata(
187179
node,

libcst/metadata/tests/test_name_provider.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pathlib import Path
77
from tempfile import TemporaryDirectory
88
from textwrap import dedent
9-
from typing import cast, Collection, Dict, Mapping, Optional, Set, Tuple
9+
from typing import Collection, Dict, Mapping, Optional, Set, Tuple
1010

1111
import libcst as cst
1212
from libcst import ensure_type
@@ -31,9 +31,7 @@ def __init__(self) -> None:
3131
self.qnames: Dict["CSTNode", Collection[QualifiedName]] = {}
3232

3333
def on_visit(self, node: cst.CSTNode) -> bool:
34-
qname = cast(
35-
Collection[QualifiedName], self.get_metadata(QualifiedNameProvider, node)
36-
)
34+
qname = self.get_metadata(QualifiedNameProvider, node)
3735
self.qnames[node] = qname
3836
return True
3937

0 commit comments

Comments
 (0)